Unverified commit 8b0a2d19, authored by Scott Cyphers, committed by GitHub

Some more changes of nodes to outputs (#3315)

* Some more changes of nodes to outputs

* Change names for OutputVector helpers to avoid API change
parent 81597f3a
@@ -151,59 +151,58 @@ namespace ngraph
     /// Return a pointer to the node that produces the wrapped value.
     /// If no additional reshape or broadcast op was needed, simply return \p node.
     static std::shared_ptr<Node>
-        add_required_ops(const std::shared_ptr<Node>& node,
-                         const ngraph::Shape& node_shape_after_possible_reshaping,
-                         const ngraph::AxisSet& node_broadcast_axes,
-                         const ngraph::Shape& node_final_shape)
+        add_required_ops(const Output<Node>& value,
+                         const ngraph::Shape& shape_after_possible_reshaping,
+                         const ngraph::AxisSet& broadcast_axes,
+                         const ngraph::Shape& final_shape)
     {
-        std::shared_ptr<Node> return_node{node};
-        if (node->get_shape() != node_shape_after_possible_reshaping)
+        Output<Node> return_value{value};
+        if (value.get_shape() != shape_after_possible_reshaping)
         {
             // tell reshape to examine input dimensions in order
-            ngraph::AxisVector order = ngraph::get_default_order(node->get_shape());
-            return_node = std::make_shared<ngraph::op::Reshape>(
-                return_node, order, node_shape_after_possible_reshaping);
+            ngraph::AxisVector order = ngraph::get_default_order(value.get_shape());
+            return_value = std::make_shared<ngraph::op::Reshape>(
+                return_value, order, shape_after_possible_reshaping);
         }
-        if (node_final_shape != node_shape_after_possible_reshaping)
+        if (final_shape != shape_after_possible_reshaping)
         {
-            return_node = std::make_shared<ngraph::op::Broadcast>(
-                return_node, node_final_shape, node_broadcast_axes);
+            return_value = std::make_shared<ngraph::op::Broadcast>(
+                return_value, final_shape, broadcast_axes);
         }
-        return return_node;
+        return return_value.get_node_shared_ptr();
     }

     std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
-        numpy_broadcast(const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>& args)
+        numpy_broadcast(const std::pair<Output<Node>, Output<Node>>& args)
     {
-        NGRAPH_CHECK(args.first);
-        NGRAPH_CHECK(args.second);
+        NGRAPH_CHECK(args.first.get_node());
+        NGRAPH_CHECK(args.second.get_node());

-        const ngraph::Shape& arg1_in_shape = args.first->get_shape();
-        const ngraph::Shape& arg2_in_shape = args.second->get_shape();
+        const ngraph::Shape& arg1_in_shape = args.first.get_shape();
+        const ngraph::Shape& arg2_in_shape = args.second.get_shape();

         // Handle the trivial case...
         if (arg1_in_shape == arg2_in_shape)
         {
-            return args;
+            return make_pair(args.first.as_single_output_node(),
+                             args.second.as_single_output_node());
         }

         Autobroadcast_plan plan =
             compute_shapes_and_broadcast_axes(arg1_in_shape, arg2_in_shape);

-        std::shared_ptr<Node> arg1_out =
-            add_required_ops(args.first,
-                             plan.m_arg1_shape_after_possible_reshaping,
-                             plan.m_arg1_broadcast_axes,
-                             plan.m_final_shape);
+        auto arg1_out = add_required_ops(args.first,
+                                         plan.m_arg1_shape_after_possible_reshaping,
+                                         plan.m_arg1_broadcast_axes,
+                                         plan.m_final_shape);

-        std::shared_ptr<Node> arg2_out =
-            add_required_ops(args.second,
-                             plan.m_arg2_shape_after_possible_reshaping,
-                             plan.m_arg2_broadcast_axes,
-                             plan.m_final_shape);
+        auto arg2_out = add_required_ops(args.second,
+                                         plan.m_arg2_shape_after_possible_reshaping,
+                                         plan.m_arg2_broadcast_axes,
+                                         plan.m_final_shape);

         return {arg1_out, arg2_out};
     }
...
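Reading note: after this hunk, `numpy_broadcast` takes a pair of `Output<Node>` values, and `std::shared_ptr<Node>` still converts to `Output<Node>` implicitly, so existing call sites keep compiling. A minimal sketch of a call site; the parameter names and shapes are illustrative, not from this commit:

    // Two parameters with NumPy-compatible shapes.
    std::shared_ptr<ngraph::Node> a = std::make_shared<ngraph::op::Parameter>(
        ngraph::element::f32, ngraph::Shape{8, 1, 6});
    std::shared_ptr<ngraph::Node> b = std::make_shared<ngraph::op::Parameter>(
        ngraph::element::f32, ngraph::Shape{7, 1});
    // shared_ptr<Node> converts implicitly to Output<Node>, so the pair
    // overload accepts both legacy and new-style arguments.
    auto bcast = ngraph::builder::numpy_broadcast({a, b});
    // bcast.first and bcast.second both have the broadcast shape {8, 7, 6}.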
@@ -42,7 +42,7 @@ namespace ngraph
     static std::string error_str(const ngraph::Shape& shape1, const ngraph::Shape& shape2);
 };

-/// \brief Wrap two graph nodes, if necessary, to obtain values with identical shapes,
+/// \brief Wrap two graph values, if necessary, to obtain values with identical shapes,
 ///        using NumPy's auto-broadcast rules.
 ///
 /// The elements in the std::pair returned by this function correspond to those supplied
@@ -71,7 +71,7 @@ namespace ngraph
 ///
 /// \exception ngraph::builder::autobroadcast_incompatible_shapes
 std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
-    numpy_broadcast(const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>& args);
+    numpy_broadcast(const std::pair<Output<Node>, Output<Node>>& args);

 /// Create a new \p NodeType node, and any additional nodes required to simulate NumPy-style autobroadcast
 /// semantics. Intended for binary operations such as "Add".
@@ -87,11 +87,10 @@ namespace ngraph
 /// \exception ngraph::builder::autobroadcast_incompatible_shapes
 template <typename NodeType>
 std::shared_ptr<NodeType>
-    make_with_numpy_broadcast(const std::shared_ptr<Node>& operand1_reshapeable,
-                              const std::shared_ptr<Node>& operand2_reshapeable)
+    make_with_numpy_broadcast(const Output<Node>& operand1_reshapeable,
+                              const Output<Node>& operand2_reshapeable)
 {
-    std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> shaped_op1_op2 =
-        numpy_broadcast({operand1_reshapeable, operand2_reshapeable});
+    auto shaped_op1_op2 = numpy_broadcast({operand1_reshapeable, operand2_reshapeable});
     return std::make_shared<NodeType>(shaped_op1_op2.first, shaped_op1_op2.second);
 }
@@ -112,16 +111,13 @@ namespace ngraph
 ///
 /// \exception ngraph::builder::autobroadcast_incompatible_shapes
 template <typename NodeType>
-std::shared_ptr<NodeType>
-    make_with_numpy_broadcast(const std::shared_ptr<Node>& operand1,
-                              const std::shared_ptr<Node>& operand2_reshapeable,
-                              const std::shared_ptr<Node>& operand3_reshapeable)
+std::shared_ptr<Node> make_with_numpy_broadcast(const Output<Node>& operand1,
+                                                const Output<Node>& operand2_reshapeable,
+                                                const Output<Node>& operand3_reshapeable)
 {
-    std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> shaped_op2_op3 =
-        numpy_broadcast({operand2_reshapeable, operand3_reshapeable});
+    auto shaped_op2_op3 = numpy_broadcast({operand2_reshapeable, operand3_reshapeable});
     return std::make_shared<NodeType>(
         operand1, shaped_op2_op3.first, shaped_op2_op3.second);
 }
 } // namespace builder
} // namespace ngraph
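For orientation, a usage sketch of the two-operand helper, with `a` and `b` as in the sketch above (the op choice is illustrative):

    // Reshape/Broadcast ops are inserted as needed, then the binary op is built.
    auto sum = ngraph::builder::make_with_numpy_broadcast<ngraph::op::Add>(a, b);

Note that in the last hunk the three-operand overload's return type changes from `std::shared_ptr<NodeType>` to `std::shared_ptr<Node>`, even though the body still constructs a `NodeType`; callers that relied on receiving the concrete type will see that change.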
@@ -37,9 +37,9 @@ namespace ngraph
 namespace builder
 {
-    std::shared_ptr<Node> numpy_transpose(const std::shared_ptr<Node>& node, AxisVector order)
+    std::shared_ptr<Node> numpy_transpose(const Output<Node>& value, AxisVector order)
     {
-        auto in_shape = node->get_shape();
+        auto in_shape = value.get_shape();
         // default, reverse the order of the axes
         if (order.size() == 0)
         {
@@ -74,7 +74,7 @@ namespace ngraph
             out_shape.push_back(in_shape[order[i]]);

         // do the reshaping with the order
-        return std::make_shared<ngraph::op::Reshape>(node, order, out_shape);
+        return std::make_shared<ngraph::op::Reshape>(value, order, out_shape);
     }
 } // namespace builder
...
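A behavior sketch (shapes illustrative): with an empty `order`, the axes are reversed, matching `numpy.transpose`'s default.

    std::shared_ptr<ngraph::Node> t = std::make_shared<ngraph::op::Parameter>(
        ngraph::element::f32, ngraph::Shape{2, 3, 4});
    // Empty order reverses the axes.
    auto reversed = ngraph::builder::numpy_transpose(t);                        // shape {4, 3, 2}
    // An explicit order permutes them.
    auto permuted = ngraph::builder::numpy_transpose(t, ngraph::AxisVector{1, 0, 2}); // shape {3, 2, 4}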
@@ -45,7 +45,6 @@ namespace ngraph
 /// | Type                                                                          | Description                                                                                              |
 /// | ----------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------- |
 /// | \f$E[d_{n-1},\dots,d_0)]\textit{ or }E[d_{order[0]},\dots,d_{order[n-1]}]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the axes reordered via Numpy Transpose rules |
-std::shared_ptr<Node> numpy_transpose(const std::shared_ptr<Node>& node,
-                                      AxisVector order = {});
+std::shared_ptr<Node> numpy_transpose(const Output<Node>& value, AxisVector order = {});
 } // namespace builder
} // namespace ngraph
@@ -22,8 +22,10 @@
 using namespace std;
 using namespace ngraph;

-op::Clamp::Clamp(const shared_ptr<Node>& data, const double min, const double max)
-    : FusedOp("Clamp", {data})
+const string op::Clamp::type_name{"Clamp"};
+
+op::Clamp::Clamp(const Output<Node>& data, const double min, const double max)
+    : FusedOp({data})
     , m_min{min}
     , m_max{max}
 {
...
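The static `type_name` member introduced here replaces the string literal previously passed to the `FusedOp` base constructor. A construction sketch; the parameter shape and clamp bounds are illustrative only:

    std::shared_ptr<ngraph::Node> data = std::make_shared<ngraph::op::Parameter>(
        ngraph::element::f32, ngraph::Shape{4});
    // ReLU6-style clamp of the input to [0, 6].
    auto clamp = std::make_shared<ngraph::op::Clamp>(data, 0.0, 6.0);
    // description() now returns the static member, i.e.
    // clamp->description() == "Clamp".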
@@ -32,12 +32,15 @@ namespace ngraph
 class Clamp : public ngraph::op::util::FusedOp
 {
 public:
+    NGRAPH_API
+    static const std::string type_name;
+    const std::string& description() const override { return type_name; }
     /// \brief Constructs a Clamp node.
     ///
     /// \param data - Node producing the input tensor
     /// \param min - the lower bound of the <min;max> range
     /// \param max - the upper bound of the <min;max> range
-    Clamp(const std::shared_ptr<ngraph::Node>& data, const double min, const double max);
+    Clamp(const Output<ngraph::Node>& data, const double min, const double max);

     void pre_validate_and_infer_types() override;
...
@@ -74,7 +74,8 @@ NodeVector op::Gemm::decompose_op() const
     C = std::make_shared<ngraph::op::Multiply>(beta_node, C);

     // alpha * A' * B' + beta * C
-    NodeVector broadcasted_nodes = ngraph::op::numpy_style_broadcast({a_dot_b, C});
+    OutputVector broadcasted_nodes =
+        ngraph::op::numpy_style_broadcast_values(OutputVector{a_dot_b, C});
     // The input tensor `C` should be "unidirectionally broadcastable" to the `a_dot_b` tensor.
     // Numpy style broadcast is bidirectional, so we only use the second output from broadcasting.
     return {std::make_shared<ngraph::op::Add>(a_dot_b, broadcasted_nodes.at(1))};
...
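A worked shape example for the "only the second output" comment (values illustrative, not from the commit): if `a_dot_b` has shape {3, 4} and `C` has shape {4}, `numpy_style_broadcast_values` returns both values at shape {3, 4}; `a_dot_b` is consumed as-is and only `broadcasted_nodes.at(1)`, the broadcast copy of `C`, feeds the `Add`. That preserves ONNX Gemm's contract that `C` be unidirectionally broadcastable to `A' * B'`.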
@@ -37,34 +37,34 @@ op::PRelu::PRelu(const shared_ptr<Node>& data, const shared_ptr<Node>& slope)
 NodeVector op::PRelu::decompose_op() const
 {
-    auto data = get_argument(0);
-    auto data_shape = data->get_shape();
-    auto slope = get_argument(1);
-    auto slope_shape = slope->get_shape();
+    auto data = input(0).get_source_output();
+    auto data_shape = data.get_shape();
+    auto slope = input(1).get_source_output();
+    auto slope_shape = slope.get_shape();

     if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
     {
         auto it = std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
         auto index = std::distance(std::begin(data_shape), it);
-        slope = make_broadcast_node(slope, data->get_shape(), index);
+        slope = make_broadcast_node(slope, data.get_shape(), index);
     }
     else if (data_shape != slope_shape)
     {
-        slope = numpy_style_broadcast({slope, data})[0];
+        slope = numpy_style_broadcast_values({slope, data})[0];
     }

     // x <  0 => f(x) = x * slope
     // x >= 0 => f(x) = x
     std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
-        data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
-    zero_node = make_broadcast_node(zero_node, data->get_shape());
+        data.get_element_type(), ngraph::Shape{}, std::vector<double>{0});
+    zero_node = make_broadcast_node(zero_node, data.get_shape());

     std::shared_ptr<ngraph::Node> negative_map = std::make_shared<ngraph::op::Convert>(
-        std::make_shared<ngraph::op::Less>(data, zero_node), data->get_element_type());
+        std::make_shared<ngraph::op::Less>(data, zero_node), data.get_element_type());

     std::shared_ptr<ngraph::Node> positive_map = std::make_shared<ngraph::op::Convert>(
-        std::make_shared<ngraph::op::Greater>(data, zero_node), data->get_element_type());
+        std::make_shared<ngraph::op::Greater>(data, zero_node), data.get_element_type());

     slope = negative_map * slope + positive_map;
...
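The mask arithmetic above blends the two PRelu branches into a single multiplier. A scalar sketch of the same math (a hypothetical standalone function, purely illustrative):

    // negative_map/positive_map mirror the 0/1 indicators produced by the
    // Less/Greater comparisons in the graph decomposition above.
    double prelu_scalar(double x, double slope)
    {
        double negative_map = (x < 0.0) ? 1.0 : 0.0;
        double positive_map = (x > 0.0) ? 1.0 : 0.0;
        // At x == 0 both masks are 0, so the multiplier is 0 and f(0) = 0,
        // which still agrees with the identity branch f(x) = x.
        return x * (negative_map * slope + positive_map);
    }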
@@ -31,12 +31,12 @@ op::ScaleShift::ScaleShift(const std::shared_ptr<ngraph::Node>& data,
 NodeVector op::ScaleShift::decompose_op() const
 {
-    auto data = get_argument(0);
-    auto scale = get_argument(1);
-    auto shift = get_argument(2);
+    auto data = input(0).get_source_output();
+    auto scale = input(1).get_source_output();
+    auto shift = input(2).get_source_output();

     // broadcast all data
-    auto broadcasted_nodes = numpy_style_broadcast({data, scale, shift});
+    auto broadcasted_nodes = numpy_style_broadcast_values({data, scale, shift});

     data = broadcasted_nodes[0];
     scale = broadcasted_nodes[1];
     shift = broadcasted_nodes[2];
...
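ScaleShift computes `scale * data + shift` element-wise, so all three inputs must end up with one common shape, which is why they are broadcast jointly here. Illustrative shapes (a sketch, not from the commit): `data` {2, 3}, `scale` {3}, and `shift` {} all come back from `numpy_style_broadcast_values` at shape {2, 3}.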
@@ -32,10 +32,10 @@ op::SquaredDifference::SquaredDifference(const shared_ptr<Node>& x1, const share
 NodeVector op::SquaredDifference::decompose_op() const
 {
-    const auto x1 = get_argument(0);
-    const auto x2 = get_argument(1);
+    const auto x1 = input(0).get_source_output();
+    const auto x2 = input(1).get_source_output();

-    const auto broadcasted = numpy_style_broadcast({x1, x2});
+    const auto broadcasted = numpy_style_broadcast_values({x1, x2});

     const auto difference = broadcasted.at(0) - broadcasted.at(1);
...
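For reference, after the bidirectional broadcast the op computes, element-wise, f(x1, x2) = (x1 - x2)^2; the squaring step falls in the part of `decompose_op()` that this diff elides.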
@@ -104,6 +104,19 @@ static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
     return get_numpy_broadcast_shapes(input_shapes);
 }

+static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
+    get_numpy_broadcast_shapes(const ngraph::OutputVector& values)
+{
+    std::vector<ngraph::Shape> input_shapes;
+
+    for (const auto& input : values)
+    {
+        input_shapes.push_back(input.get_shape());
+    }
+
+    return get_numpy_broadcast_shapes(input_shapes);
+}
+
 /// \brief Broadcast input node.
 ///
 /// \note The source shape does not have to be the actual shape of input node. However
@@ -112,21 +125,21 @@ static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
 ///       The ranks of source_shape and output_shape must be equal. This means that the
 ///       source_shape has to be padded with ones for this operation.
 ///
-/// \param[in] node          The input Node to be broadcast.
+/// \param[in] value         The input Node to be broadcast.
 /// \param[in] output_shape  The output shape.
 /// \param[in] source_shape  The source shape from which we want to broadcast input node.
 ///
 /// \return The broadcasted Node.
 ///
 static std::shared_ptr<ngraph::Node>
-    broadcast_node_numpy_style(const std::shared_ptr<ngraph::Node>& node,
+    broadcast_node_numpy_style(const ngraph::Output<ngraph::Node>& value,
                                const ngraph::Shape& output_shape,
                                const ngraph::Shape& source_shape)
 {
     // If node already has the required shape, return original node
-    if (output_shape == node->get_shape())
+    if (output_shape == value.get_shape())
     {
-        return node;
+        return value.as_single_output_node();
     }

     if (source_shape.size() != output_shape.size())
@@ -153,16 +166,35 @@ static std::shared_ptr<ngraph::Node>
     }

     // Remove axes which have length of 1 from source shape
-    auto broadcasted_node = std::make_shared<ngraph::op::Reshape>(
-        node, ngraph::get_default_order(node->get_shape()), squeezed_shape);
+    ngraph::Output<ngraph::Node> broadcasted_value = std::make_shared<ngraph::op::Reshape>(
+        value, ngraph::get_default_order(value.get_shape()), squeezed_shape);

-    return std::make_shared<ngraph::op::Broadcast>(broadcasted_node, output_shape, broadcast_axes);
+    return std::make_shared<ngraph::op::Broadcast>(broadcasted_value, output_shape, broadcast_axes);
 }

 namespace ngraph
 {
     namespace op
     {
+        OutputVector numpy_style_broadcast_values(const OutputVector& values)
+        {
+            if (values.size() <= 1)
+            {
+                return values;
+            }
+
+            // find the output tensor's shape, then broadcast all inputs so that they are compatible
+            auto bcast_shapes = get_numpy_broadcast_shapes(values);
+
+            OutputVector broadcasted_inputs;
+            for (std::size_t i = 0; i < values.size(); ++i)
+            {
+                broadcasted_inputs.push_back(broadcast_node_numpy_style(
+                    values[i], bcast_shapes.first, bcast_shapes.second[i]));
+            }
+            return broadcasted_inputs;
+        }
+
         NodeVector numpy_style_broadcast(const NodeVector& inputs)
         {
             if (inputs.size() <= 1)
@@ -176,19 +208,17 @@ namespace ngraph
             NodeVector broadcasted_inputs;
             for (std::size_t i = 0; i < inputs.size(); ++i)
             {
-                const std::shared_ptr<ngraph::Node> input_node = inputs[i];
                 broadcasted_inputs.push_back(broadcast_node_numpy_style(
                     inputs[i], bcast_shapes.first, bcast_shapes.second[i]));
             }
             return broadcasted_inputs;
         }

-        std::shared_ptr<ngraph::Node>
-            numpy_style_broadcast(const std::shared_ptr<ngraph::Node>& input_node,
-                                  const Shape& shape)
+        std::shared_ptr<ngraph::Node> numpy_style_broadcast(const Output<ngraph::Node>& value,
+                                                            const Shape& shape)
         {
-            auto bcast_shape = get_numpy_broadcast_shapes({input_node->get_shape(), shape});
-            return broadcast_node_numpy_style(input_node, bcast_shape.first, bcast_shape.second[0]);
+            auto bcast_shape = get_numpy_broadcast_shapes({value.get_shape(), shape});
+            return broadcast_node_numpy_style(value, bcast_shape.first, bcast_shape.second[0]);
         }

         NodeVector
@@ -227,6 +257,42 @@ namespace ngraph
                     broadcast_node_numpy_style(right, right_output_shape, right_full_shape)};
         }

+        OutputVector
+            numpy_style_broadcast_values_for_matmul_operation(const Output<ngraph::Node>& left,
+                                                              const Output<ngraph::Node>& right)
+        {
+            const auto& left_shape = left.get_shape();
+            const auto& right_shape = right.get_shape();
+            // Broadcast only _stack of matrices_ axes.
+            const auto& numpy_shapes = get_numpy_broadcast_shapes(
+                {Shape{std::begin(left_shape), std::next(std::end(left_shape), -2)},
+                 Shape{std::begin(right_shape), std::next(std::end(right_shape), -2)}});
+
+            // Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
+            auto left_output_shape = numpy_shapes.first;
+            auto right_output_shape = numpy_shapes.first;
+            // Append the last two axes original dimensions.
+            left_output_shape.insert(std::end(left_output_shape),
+                                     std::next(std::begin(left_shape), left_shape.size() - 2),
+                                     std::end(left_shape));
+            right_output_shape.insert(std::end(right_output_shape),
+                                      std::next(std::begin(right_shape), right_shape.size() - 2),
+                                      std::end(right_shape));
+
+            auto left_full_shape = numpy_shapes.second.at(0);
+            auto right_full_shape = numpy_shapes.second.at(1);
+            // Append the last two axes original dimensions.
+            left_full_shape.insert(std::end(left_full_shape),
+                                   std::next(std::begin(left_shape), left_shape.size() - 2),
+                                   std::end(left_shape));
+            right_full_shape.insert(std::end(right_full_shape),
+                                    std::next(std::begin(right_shape), right_shape.size() - 2),
+                                    std::end(right_shape));
+
+            return {broadcast_node_numpy_style(left, left_output_shape, left_full_shape),
+                    broadcast_node_numpy_style(right, right_output_shape, right_full_shape)};
+        }
+
         NodeVector
             legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
                                                         const std::shared_ptr<ngraph::Node>& right,
@@ -288,6 +354,67 @@ namespace ngraph
             return {left, broadcast_right};
         }

+        OutputVector
+            legacy_style_broadcast_values_for_binary_operation(const Output<ngraph::Node>& left,
+                                                               const Output<ngraph::Node>& right,
+                                                               size_t start_match_axis)
+        {
+            const auto& left_shape = left.get_shape();
+            const auto& right_shape = right.get_shape();
+
+            bool dimensions_identical = (left_shape == right_shape);
+            if (dimensions_identical)
+            {
+                return {left, right};
+            }
+
+            // Prepare new shape of right operand for broadcasting
+            // Remove dimensions with length=1 from back
+            auto new_right_shape = right_shape;
+            for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension)
+            {
+                if (new_right_shape[dimension] == 1)
+                {
+                    new_right_shape.pop_back();
+                }
+                else
+                {
+                    break;
+                }
+            }
+
+            // Find first dimensions at front with length different from 1
+            std::size_t num_ones = 0;
+            for (std::size_t dimension : new_right_shape)
+            {
+                if (dimension == 1)
+                {
+                    ++num_ones;
+                }
+                else
+                {
+                    break;
+                }
+            }
+
+            // Remove dimensions with length=1 from front
+            new_right_shape.erase(std::begin(new_right_shape),
+                                  std::next(std::begin(new_right_shape), num_ones));
+
+            auto reshape_right = std::make_shared<ngraph::op::Reshape>(
+                right, ngraph::get_default_order(right_shape), new_right_shape);
+
+            // Move broadcast start axis parameter to right
+            start_match_axis += num_ones;
+
+            auto broadcast_right = std::make_shared<ngraph::op::Broadcast>(
+                reshape_right,
+                left_shape,
+                calculate_broadcast_axes(left_shape, new_right_shape, start_match_axis));
+
+            return {left, broadcast_right};
+        }
+
         AxisSet calculate_broadcast_axes(const Shape& output_shape,
                                          const Shape& input_shape,
                                          std::size_t start_match_axis)
...
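A worked example of the legacy (pre-opset-7 ONNX) broadcast helper added above, with hypothetical shapes to make the trimming of length-1 dimensions concrete:

    // Assumed values: left and right are Output<ngraph::Node> with shapes
    // {2, 3, 4, 5} and {3, 4, 1}, start_match_axis = 1 (all illustrative).
    // 1. Trailing 1s are popped from right   -> new_right_shape {3, 4}.
    // 2. No leading 1s remain (num_ones = 0) -> start_match_axis stays 1.
    // 3. right is reshaped to {3, 4}, then Broadcast to {2, 3, 4, 5} along
    //    the unmatched axes {0, 3}.
    auto results =
        ngraph::op::legacy_style_broadcast_values_for_binary_operation(left, right, 1);
    // results[0] is left, untouched; results[1] has shape {2, 3, 4, 5}.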
@@ -32,19 +32,45 @@ namespace ngraph
 /// \param inputs  Original list of inputs
 ///
 /// \return Numpy-style broadcasted list of nodes.
-NodeVector numpy_style_broadcast(const NodeVector& inputs);
+NodeVector numpy_style_broadcast(const NodeVector& inputs)
+    NGRAPH_DEPRECATED("Replace with numpy_style_value_broadcast");

-/// \brief Cast shape of an input node to the requested output shape using NumPy's broadcasting rules
+/// \brief Cast shape of all input nodes for an element-wise operation that requires shape-compatibility
+///
+/// \param values  Original list of inputs
+///
+/// \return Numpy-style broadcasted list of nodes.
+OutputVector numpy_style_broadcast_values(const OutputVector& values);
+
+/// \brief Cast shape of an output to the requested output shape using NumPy's broadcasting rules
 ///
-/// \param input_node  original input node
+/// \param value  original value
 /// \param shape  requested output shape
 ///
-/// \return Broadcast node.
-std::shared_ptr<ngraph::Node>
-    numpy_style_broadcast(const std::shared_ptr<ngraph::Node>& input_node,
-                          const Shape& shape);
+/// \return Broadcast output.
+std::shared_ptr<Node> numpy_style_broadcast(const Output<Node>& value, const Shape& shape);
+
+/// \brief Cast shape of two outputs to make them compatible for an element-wise binary operation.
+///
+/// If necessary the right-hand-side argument will be broadcast to match the shape
+/// of left-hand-side argument. The starting of the mutually equal shape is
+/// specified by the argument "start_match_axis", and if it is not set,
+/// suffix matching is assumed.
+///
+/// This style of broadcast was used in ONNX Op sets prior to version 7, where it was
+/// replaced by numpy-style broadcasting.
+///
+/// \param left             Node which contain input of binary op.
+/// \param right            Node which contain input of binary op.
+/// \param start_match_axis position in shape denoting start of the mutually equal shape
+///
+/// \return Left and right node after broadcasting.
+NodeVector legacy_style_broadcast_for_binary_operation(const std::shared_ptr<Node>& left,
+                                                       const std::shared_ptr<Node>& right,
+                                                       size_t start_match_axis)
+    NGRAPH_DEPRECATED("Replace with legacy_style_value_broadcast_for_binary_operation");

-/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
+/// \brief Cast shape of two outputs to make them compatible for an element-wise binary operation.
 ///
 /// If necessary the right-hand-side argument will be broadcast to match the shape
 /// of left-hand-side argument. The starting of the mutually equal shape is
@@ -59,10 +85,9 @@ namespace ngraph
 /// \param start_match_axis position in shape denoting start of the mutually equal shape
 ///
 /// \return Left and right node after broadcasting.
-NodeVector
-    legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
-                                                const std::shared_ptr<ngraph::Node>& right,
-                                                std::size_t start_match_axis);
+OutputVector legacy_style_broadcast_values_for_binary_operation(const Output<Node>& left,
+                                                                const Output<Node>& right,
+                                                                size_t start_match_axis);

 /// \brief Broadcast shape of two nodes to make them compatible for a matrix multiplication.
 ///
@@ -76,9 +101,24 @@ namespace ngraph
 ///
 /// \return The vector containing both nodes broadcasted.
 ///
-NodeVector
-    numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<ngraph::Node>& left,
-                                               const std::shared_ptr<ngraph::Node>& right);
+NodeVector numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<Node>& left,
+                                                      const std::shared_ptr<Node>& right)
+    NGRAPH_DEPRECATED("Replace with numpy_style_broadcast_value_for_matmul_operation.");
+
+/// \brief Broadcast shape of two nodes to make them compatible for a matrix multiplication.
+///
+/// \note This function is reflecting broadcasting behaviour of NumPy's `matmul` operation
+///       (https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html)
+///       This mean that only \"stack of matrices\" axes are bidirectionally broadcasted.
+///       The last two dimension are left untouched.
+///
+/// \param[in] left   The Node providing data for the left-hand side of matrix multiplication.
+/// \param[in] right  The Node providing data for the right-hand side of matrix multiplication.
+///
+/// \return The vector containing both outputs broadcasted.
+///
+OutputVector numpy_style_broadcast_values_for_matmul_operation(const Output<Node>& left,
+                                                               const Output<Node>& right);

 /// \brief Generate a list of broadcast axes.
 ///
@@ -118,22 +158,21 @@ namespace ngraph
         output_shape, input_shape, output_shape.size() - input_shape.size());
 }

-inline std::shared_ptr<ngraph::Node>
-    make_broadcast_node(const std::shared_ptr<ngraph::Node>& node, ngraph::Shape new_shape)
+inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& output,
+                                                 Shape new_shape)
 {
-    return std::make_shared<ngraph::op::Broadcast>(
-        node, new_shape, calculate_broadcast_axes(new_shape, node->get_shape()));
+    return std::make_shared<op::Broadcast>(
+        output, new_shape, calculate_broadcast_axes(new_shape, output.get_shape()));
 }

-inline std::shared_ptr<ngraph::Node>
-    make_broadcast_node(const std::shared_ptr<ngraph::Node>& node,
-                        const ngraph::Shape& new_shape,
-                        std::size_t start_match_axis)
+inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& value,
+                                                 const Shape& new_shape,
+                                                 std::size_t start_match_axis)
 {
-    return std::make_shared<ngraph::op::Broadcast>(
-        node,
-        new_shape,
-        calculate_broadcast_axes(new_shape, node->get_shape(), start_match_axis));
+    return std::make_shared<op::Broadcast>(
+        value,
+        new_shape,
+        calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis));
 }
 } // namespace op
} // namespace ngraph
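A shape sketch for the matmul-style variant declared above (hypothetical shapes): for `left` {1, 5, 2, 3} and `right` {4, 1, 3, 6}, only the leading "stack of matrices" axes {1, 5} and {4, 1} are broadcast, both to {4, 5}; the results then have shapes {4, 5, 2, 3} and {4, 5, 3, 6}, with the trailing matrix dimensions left untouched for the subsequent multiplication.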
@@ -69,32 +69,29 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size
     return afunc;
 }

-shared_ptr<Node> op::util::RNNCellBase::add(const shared_ptr<Node>& lhs,
-                                            const shared_ptr<Node>& rhs)
+shared_ptr<Node> op::util::RNNCellBase::add(const Output<Node>& lhs, const Output<Node>& rhs)
 {
-    auto args = op::numpy_style_broadcast({lhs, rhs});
+    auto args = op::numpy_style_broadcast_values({lhs, rhs});
     return {make_shared<op::Add>(args.at(0), args.at(1))};
 }

-shared_ptr<Node> op::util::RNNCellBase::sub(const shared_ptr<Node>& lhs,
-                                            const shared_ptr<Node>& rhs)
+shared_ptr<Node> op::util::RNNCellBase::sub(const Output<Node>& lhs, const Output<Node>& rhs)
 {
-    auto args = op::numpy_style_broadcast({lhs, rhs});
+    auto args = op::numpy_style_broadcast_values({lhs, rhs});
     return {make_shared<op::Subtract>(args.at(0), args.at(1))};
 }

-shared_ptr<Node> op::util::RNNCellBase::mul(const shared_ptr<Node>& lhs,
-                                            const shared_ptr<Node>& rhs)
+shared_ptr<Node> op::util::RNNCellBase::mul(const Output<Node>& lhs, const Output<Node>& rhs)
 {
-    auto args = op::numpy_style_broadcast({lhs, rhs});
+    auto args = op::numpy_style_broadcast_values({lhs, rhs});
     return {make_shared<op::Multiply>(args.at(0), args.at(1))};
 }

-shared_ptr<Node> op::util::RNNCellBase::clip(const shared_ptr<Node>& data) const
+shared_ptr<Node> op::util::RNNCellBase::clip(const Output<Node>& data) const
 {
     if (m_clip == 0.f)
     {
-        return data;
+        return data.as_single_output_node();
     }

     return make_shared<op::Clamp>(data, -m_clip, m_clip);
...
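All three arithmetic helpers follow the same shape after this rewrite; a standalone equivalent of `add` (a sketch; the variable names and shapes are assumptions, not from the commit):

    // gates and bias are std::shared_ptr<ngraph::Node> values with shapes
    // {batch, 4 * hidden} and {4 * hidden}; the bias row numpy-broadcasts.
    auto args = ngraph::op::numpy_style_broadcast_values({gates, bias});
    auto sum = std::make_shared<ngraph::op::Add>(args.at(0), args.at(1));

Note that `clip` can no longer return its argument directly: an `Output<Node>` is not a `shared_ptr<Node>`, so the pass-through case converts back via `as_single_output_node()`.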
@@ -81,8 +81,7 @@ namespace ngraph
 ///
 /// \return Node with element-wise add operation.
 ///
-static std::shared_ptr<Node> add(const std::shared_ptr<Node>& lhs,
-                                 const std::shared_ptr<Node>& rhs);
+static std::shared_ptr<Node> add(const Output<Node>& lhs, const Output<Node>& rhs);
 ///
 /// \brief Creates node with element-wise subtract operation with numpy broadcasting.
 ///
@@ -91,8 +90,7 @@ namespace ngraph
 ///
 /// \return Node with element-wise subtract operation.
 ///
-static std::shared_ptr<Node> sub(const std::shared_ptr<Node>& lhs,
-                                 const std::shared_ptr<Node>& rhs);
+static std::shared_ptr<Node> sub(const Output<Node>& lhs, const Output<Node>& rhs);
 ///
 /// \brief Creates node with element-wise multiply operation with numpy broadcasting.
 ///
@@ -101,8 +99,7 @@ namespace ngraph
 ///
 /// \return Node with element-wise multiply operation.
 ///
-static std::shared_ptr<Node> mul(const std::shared_ptr<Node>& lhs,
-                                 const std::shared_ptr<Node>& rhs);
+static std::shared_ptr<Node> mul(const Output<Node>& lhs, const Output<Node>& rhs);
 ///
 /// \brief Creates node with element-wise clip operation with numpy broadcasting.
 ///
@@ -110,7 +107,7 @@ namespace ngraph
 ///
 /// \return Node with element-wise clip operation.
 ///
-std::shared_ptr<Node> clip(const std::shared_ptr<Node>& data) const;
+std::shared_ptr<Node> clip(const Output<Node>& data) const;

 private:
     const std::size_t m_hidden_size;
...