Unverified Commit 8b0a2d19 authored by Scott Cyphers, committed by GitHub

Some more changes of nodes to outputs (#3315)

* Some more changes of nodes to outputs

* Change names for OutputVector helpers to avoid API change
parent 81597f3a
@@ -151,59 +151,58 @@ namespace ngraph
/// Return a pointer to the node that produces the wrapped value.
/// If no additional reshape or broadcast op was needed, simply return \p node.
static std::shared_ptr<Node>
- add_required_ops(const std::shared_ptr<Node>& node,
- const ngraph::Shape& node_shape_after_possible_reshaping,
- const ngraph::AxisSet& node_broadcast_axes,
- const ngraph::Shape& node_final_shape)
+ add_required_ops(const Output<Node>& value,
+ const ngraph::Shape& shape_after_possible_reshaping,
+ const ngraph::AxisSet& broadcast_axes,
+ const ngraph::Shape& final_shape)
{
- std::shared_ptr<Node> return_node{node};
+ Output<Node> return_value{value};
- if (node->get_shape() != node_shape_after_possible_reshaping)
+ if (value.get_shape() != shape_after_possible_reshaping)
{
// tell reshape to examine input dimensions in order
- ngraph::AxisVector order = ngraph::get_default_order(node->get_shape());
- return_node = std::make_shared<ngraph::op::Reshape>(
- return_node, order, node_shape_after_possible_reshaping);
+ ngraph::AxisVector order = ngraph::get_default_order(value.get_shape());
+ return_value = std::make_shared<ngraph::op::Reshape>(
+ return_value, order, shape_after_possible_reshaping);
}
- if (node_final_shape != node_shape_after_possible_reshaping)
+ if (final_shape != shape_after_possible_reshaping)
{
- return_node = std::make_shared<ngraph::op::Broadcast>(
- return_node, node_final_shape, node_broadcast_axes);
+ return_value = std::make_shared<ngraph::op::Broadcast>(
+ return_value, final_shape, broadcast_axes);
}
- return return_node;
+ return return_value.get_node_shared_ptr();
}
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
- numpy_broadcast(const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>& args)
+ numpy_broadcast(const std::pair<Output<Node>, Output<Node>>& args)
{
- NGRAPH_CHECK(args.first);
- NGRAPH_CHECK(args.second);
+ NGRAPH_CHECK(args.first.get_node());
+ NGRAPH_CHECK(args.second.get_node());
- const ngraph::Shape& arg1_in_shape = args.first->get_shape();
- const ngraph::Shape& arg2_in_shape = args.second->get_shape();
+ const ngraph::Shape& arg1_in_shape = args.first.get_shape();
+ const ngraph::Shape& arg2_in_shape = args.second.get_shape();
// Handle the trivial case...
if (arg1_in_shape == arg2_in_shape)
{
- return args;
+ return make_pair(args.first.as_single_output_node(),
+ args.second.as_single_output_node());
}
Autobroadcast_plan plan =
compute_shapes_and_broadcast_axes(arg1_in_shape, arg2_in_shape);
- std::shared_ptr<Node> arg1_out =
- add_required_ops(args.first,
- plan.m_arg1_shape_after_possible_reshaping,
- plan.m_arg1_broadcast_axes,
- plan.m_final_shape);
- std::shared_ptr<Node> arg2_out =
- add_required_ops(args.second,
- plan.m_arg2_shape_after_possible_reshaping,
- plan.m_arg2_broadcast_axes,
- plan.m_final_shape);
+ auto arg1_out = add_required_ops(args.first,
+ plan.m_arg1_shape_after_possible_reshaping,
+ plan.m_arg1_broadcast_axes,
+ plan.m_final_shape);
+ auto arg2_out = add_required_ops(args.second,
+ plan.m_arg2_shape_after_possible_reshaping,
+ plan.m_arg2_broadcast_axes,
+ plan.m_final_shape);
return {arg1_out, arg2_out};
}
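For orientation, a minimal sketch of the updated `numpy_broadcast` in use. The shapes, parameter nodes, and the `Add` op here are illustrative only, not part of the commit; this assumes the nGraph 0.x builder API:

```cpp
#include <memory>
#include "ngraph/ngraph.hpp"                 // assumed umbrella header
#include "ngraph/builder/autobroadcast.hpp"

std::shared_ptr<ngraph::Node> broadcast_and_add()
{
    using namespace ngraph;
    // {3, 1} and {1, 4} broadcast to {3, 4} under NumPy rules.
    auto a = std::make_shared<op::Parameter>(element::f32, Shape{3, 1});
    auto b = std::make_shared<op::Parameter>(element::f32, Shape{1, 4});
    // Output<Node> is constructible from shared_ptr<Node>, so call sites
    // written against the old signature keep compiling.
    auto shaped = builder::numpy_broadcast({a, b});
    return std::make_shared<op::Add>(shaped.first, shaped.second);
}
```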
@@ -42,7 +42,7 @@ namespace ngraph
static std::string error_str(const ngraph::Shape& shape1, const ngraph::Shape& shape2);
};
- /// \brief Wrap two graph nodes, if necessary, to obtain values with identical shapes,
+ /// \brief Wrap two graph values, if necessary, to obtain values with identical shapes,
/// using NumPy's auto-broadcast rules.
///
/// The elements in the std::pair returned by this function correspond to those supplied
@@ -71,7 +71,7 @@ namespace ngraph
///
/// \exception ngraph::builder::autobroadcast_incompatible_shapes
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
- numpy_broadcast(const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>& args);
+ numpy_broadcast(const std::pair<Output<Node>, Output<Node>>& args);
/// Create a new \p NodeType node, and any additional nodes required to simulate NumPy-style autobroadcast
/// semantics. Intended for binary operations such as "Add".
@@ -87,11 +87,10 @@ namespace ngraph
/// \exception ngraph::builder::autobroadcast_incompatible_shapes
template <typename NodeType>
std::shared_ptr<NodeType>
- make_with_numpy_broadcast(const std::shared_ptr<Node>& operand1_reshapeable,
- const std::shared_ptr<Node>& operand2_reshapeable)
+ make_with_numpy_broadcast(const Output<Node>& operand1_reshapeable,
+ const Output<Node>& operand2_reshapeable)
{
- std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> shaped_op1_op2 =
- numpy_broadcast({operand1_reshapeable, operand2_reshapeable});
+ auto shaped_op1_op2 = numpy_broadcast({operand1_reshapeable, operand2_reshapeable});
return std::make_shared<NodeType>(shaped_op1_op2.first, shaped_op1_op2.second);
}
@@ -112,16 +111,13 @@ namespace ngraph
///
/// \exception ngraph::builder::autobroadcast_incompatible_shapes
template <typename NodeType>
- std::shared_ptr<NodeType>
- make_with_numpy_broadcast(const std::shared_ptr<Node>& operand1,
- const std::shared_ptr<Node>& operand2_reshapeable,
- const std::shared_ptr<Node>& operand3_reshapeable)
+ std::shared_ptr<Node> make_with_numpy_broadcast(const Output<Node>& operand1,
+ const Output<Node>& operand2_reshapeable,
+ const Output<Node>& operand3_reshapeable)
{
- std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> shaped_op2_op3 =
- numpy_broadcast({operand2_reshapeable, operand3_reshapeable});
+ auto shaped_op2_op3 = numpy_broadcast({operand2_reshapeable, operand3_reshapeable});
return std::make_shared<NodeType>(
operand1, shaped_op2_op3.first, shaped_op2_op3.second);
}
} // namespace builder
} // namespace ngraph
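The templated helper in use, under the same assumptions (hypothetical shapes and op):

```cpp
// y is broadcast from {3} to {2, 3} before the Multiply node is built.
auto x = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto y = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{3});
auto prod = ngraph::builder::make_with_numpy_broadcast<ngraph::op::Multiply>(x, y);
```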
@@ -37,9 +37,9 @@ namespace ngraph
namespace builder
{
- std::shared_ptr<Node> numpy_transpose(const std::shared_ptr<Node>& node, AxisVector order)
+ std::shared_ptr<Node> numpy_transpose(const Output<Node>& value, AxisVector order)
{
- auto in_shape = node->get_shape();
+ auto in_shape = value.get_shape();
// default, reverse the order of the axes
if (order.size() == 0)
{
@@ -74,7 +74,7 @@ namespace ngraph
out_shape.push_back(in_shape[order[i]]);
// do the reshaping with the order
- return std::make_shared<ngraph::op::Reshape>(node, order, out_shape);
+ return std::make_shared<ngraph::op::Reshape>(value, order, out_shape);
}
} // namespace builder
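A quick sketch of the behaviour (hypothetical shapes): an empty `order` reverses the axes, otherwise the given permutation is applied:

```cpp
auto t = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4});
auto rev  = ngraph::builder::numpy_transpose(t);             // result shape {4, 3, 2}
auto perm = ngraph::builder::numpy_transpose(t, {0, 2, 1});  // result shape {2, 4, 3}
```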
@@ -45,7 +45,6 @@ namespace ngraph
/// | Type | Description |
/// | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_{n-1},\dots,d_0)]\textit{ or }E[d_{order[0]},\dots,d_{order[n-1]}]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the axes reordered via Numpy Transpose rules |
- std::shared_ptr<Node> numpy_transpose(const std::shared_ptr<Node>& node,
- AxisVector order = {});
+ std::shared_ptr<Node> numpy_transpose(const Output<Node>& value, AxisVector order = {});
} // namespace builder
} // namespace ngraph
@@ -22,8 +22,10 @@
using namespace std;
using namespace ngraph;
- op::Clamp::Clamp(const shared_ptr<Node>& data, const double min, const double max)
- : FusedOp("Clamp", {data})
+ const string op::Clamp::type_name{"Clamp"};
+
+ op::Clamp::Clamp(const Output<Node>& data, const double min, const double max)
+ : FusedOp({data})
, m_min{min}
, m_max{max}
{
@@ -32,12 +32,15 @@ namespace ngraph
class Clamp : public ngraph::op::util::FusedOp
{
public:
+ NGRAPH_API
+ static const std::string type_name;
+ const std::string& description() const override { return type_name; }
/// \brief Constructs a Clamp node.
///
/// \param data - Node producing the input tensor
/// \param min - the lower bound of the <min;max> range
/// \param max - the upper bound of the <min;max> range
- Clamp(const std::shared_ptr<ngraph::Node>& data, const double min, const double max);
+ Clamp(const Output<ngraph::Node>& data, const double min, const double max);
void pre_validate_and_infer_types() override;
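A minimal construction sketch for the updated signature (hypothetical input tensor; since `shared_ptr<Node>` converts implicitly to `Output<Node>`, existing callers are unaffected):

```cpp
auto x = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{8});
// Clamp every element into the range [0.0, 6.0].
auto clamped = std::make_shared<ngraph::op::Clamp>(x, 0.0, 6.0);
```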
@@ -74,7 +74,8 @@ NodeVector op::Gemm::decompose_op() const
C = std::make_shared<ngraph::op::Multiply>(beta_node, C);
// alpha * A' * B' + beta * C
- NodeVector broadcasted_nodes = ngraph::op::numpy_style_broadcast({a_dot_b, C});
+ OutputVector broadcasted_nodes =
+ ngraph::op::numpy_style_broadcast_values(OutputVector{a_dot_b, C});
// The input tensor `C` should be "unidirectionally broadcastable" to the `a_dot_b` tensor.
// Numpy style broadcast is bidirectional, so we only use the second output from broadcasting.
return {std::make_shared<ngraph::op::Add>(a_dot_b, broadcasted_nodes.at(1))};
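To spell out the comment in this hunk with hypothetical shapes: if `a_dot_b` is `{4, 5}` and `C` is `{5}` or `{1, 5}`, the bidirectional broadcast returns both values at `{4, 5}`, but the decomposition keeps the original `a_dot_b` and consumes only the broadcast copy of `C`:

```cpp
// Sketch only; shapes are illustrative.
OutputVector bcast = ngraph::op::numpy_style_broadcast_values(OutputVector{a_dot_b, C});
// bcast.at(0) is a_dot_b (broadcast is a no-op for it here);
// bcast.at(1) is C broadcast up to {4, 5}.
auto sum = std::make_shared<ngraph::op::Add>(a_dot_b, bcast.at(1));
```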
@@ -37,34 +37,34 @@ op::PRelu::PRelu(const shared_ptr<Node>& data, const shared_ptr<Node>& slope)
NodeVector op::PRelu::decompose_op() const
{
- auto data = get_argument(0);
- auto data_shape = data->get_shape();
- auto slope = get_argument(1);
- auto slope_shape = slope->get_shape();
+ auto data = input(0).get_source_output();
+ auto data_shape = data.get_shape();
+ auto slope = input(1).get_source_output();
+ auto slope_shape = slope.get_shape();
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{
auto it = std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
- slope = make_broadcast_node(slope, data->get_shape(), index);
+ slope = make_broadcast_node(slope, data.get_shape(), index);
}
else if (data_shape != slope_shape)
{
- slope = numpy_style_broadcast({slope, data})[0];
+ slope = numpy_style_broadcast_values({slope, data})[0];
}
// x < 0 => f(x) = x * slope
// x >= 0 => f(x) = x
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
- data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
- zero_node = make_broadcast_node(zero_node, data->get_shape());
+ data.get_element_type(), ngraph::Shape{}, std::vector<double>{0});
+ zero_node = make_broadcast_node(zero_node, data.get_shape());
std::shared_ptr<ngraph::Node> negative_map = std::make_shared<ngraph::op::Convert>(
- std::make_shared<ngraph::op::Less>(data, zero_node), data->get_element_type());
+ std::make_shared<ngraph::op::Less>(data, zero_node), data.get_element_type());
std::shared_ptr<ngraph::Node> positive_map = std::make_shared<ngraph::op::Convert>(
- std::make_shared<ngraph::op::Greater>(data, zero_node), data->get_element_type());
+ std::make_shared<ngraph::op::Greater>(data, zero_node), data.get_element_type());
slope = negative_map * slope + positive_map;
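For intuition, the arithmetic select that the last line of this hunk builds (the final multiply by `data` happens later in the decomposition, outside the hunk):

```cpp
// negative_map = 1 where data < 0, else 0; positive_map = 1 where data > 0, else 0.
// combined = negative_map * slope + positive_map, so element-wise:
//   data < 0  -> combined == slope -> data * combined == slope * data
//   data > 0  -> combined == 1     -> data * combined == data
//   data == 0 -> combined == 0     -> data * combined == 0 == data
```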
@@ -31,12 +31,12 @@ op::ScaleShift::ScaleShift(const std::shared_ptr<ngraph::Node>& data,
NodeVector op::ScaleShift::decompose_op() const
{
- auto data = get_argument(0);
- auto scale = get_argument(1);
- auto shift = get_argument(2);
+ auto data = input(0).get_source_output();
+ auto scale = input(1).get_source_output();
+ auto shift = input(2).get_source_output();
// broadcast all data
- auto broadcasted_nodes = numpy_style_broadcast({data, scale, shift});
+ auto broadcasted_nodes = numpy_style_broadcast_values({data, scale, shift});
data = broadcasted_nodes[0];
scale = broadcasted_nodes[1];
shift = broadcasted_nodes[2];
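A shape walk-through (hypothetical): with `data` `{2, 3}`, `scale` `{3}`, and a scalar `shift` `{}`, all three come back from the broadcast at `{2, 3}`, after which the op is a plain element-wise multiply-add. A sketch, assuming the element-wise operator overloads used elsewhere in this diff:

```cpp
auto broadcasted = ngraph::op::numpy_style_broadcast_values({data, scale, shift});
// y = scale * data + shift, element-wise on the shape-aligned values.
auto y = broadcasted[1] * broadcasted[0] + broadcasted[2];
```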
@@ -32,10 +32,10 @@ op::SquaredDifference::SquaredDifference(const shared_ptr<Node>& x1, const share
NodeVector op::SquaredDifference::decompose_op() const
{
- const auto x1 = get_argument(0);
- const auto x2 = get_argument(1);
+ const auto x1 = input(0).get_source_output();
+ const auto x2 = input(1).get_source_output();
- const auto broadcasted = numpy_style_broadcast({x1, x2});
+ const auto broadcasted = numpy_style_broadcast_values({x1, x2});
const auto difference = broadcasted.at(0) - broadcasted.at(1);
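The hunk ends before the square; a sketch of the rest with hypothetical shapes `x1` `{4, 1}` and `x2` `{1, 5}` (both broadcast to `{4, 5}`):

```cpp
const auto broadcasted = ngraph::op::numpy_style_broadcast_values({x1, x2});
const auto difference = broadcasted.at(0) - broadcasted.at(1);
return {difference * difference}; // element-wise (x1 - x2)^2
```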
@@ -104,6 +104,19 @@ static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
return get_numpy_broadcast_shapes(input_shapes);
}
+ static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
+ get_numpy_broadcast_shapes(const ngraph::OutputVector& values)
+ {
+ std::vector<ngraph::Shape> input_shapes;
+
+ for (const auto& input : values)
+ {
+ input_shapes.push_back(input.get_shape());
+ }
+
+ return get_numpy_broadcast_shapes(input_shapes);
+ }
/// \brief Broadcast input node.
///
/// \note The source shape does not have to be the actual shape of input node. However
@@ -112,21 +125,21 @@ static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
/// The ranks of source_shape and output_shape must be equal. This means that the
/// source_shape has to be padded with ones for this operation.
///
- /// \param[in] node The input Node to be broadcast.
+ /// \param[in] value The input Node to be broadcast.
/// \param[in] output_shape The output shape.
/// \param[in] source_shape The source shape from which we want to broadcast input node.
///
/// \return The broadcasted Node.
///
static std::shared_ptr<ngraph::Node>
- broadcast_node_numpy_style(const std::shared_ptr<ngraph::Node>& node,
+ broadcast_node_numpy_style(const ngraph::Output<ngraph::Node>& value,
const ngraph::Shape& output_shape,
const ngraph::Shape& source_shape)
{
// If node already has the required shape, return original node
- if (output_shape == node->get_shape())
+ if (output_shape == value.get_shape())
{
- return node;
+ return value.as_single_output_node();
}
if (source_shape.size() != output_shape.size())
@@ -153,16 +166,35 @@ static std::shared_ptr<ngraph::Node>
}
// Remove axes which have length of 1 from source shape
- auto broadcasted_node = std::make_shared<ngraph::op::Reshape>(
- node, ngraph::get_default_order(node->get_shape()), squeezed_shape);
+ ngraph::Output<ngraph::Node> broadcasted_value = std::make_shared<ngraph::op::Reshape>(
+ value, ngraph::get_default_order(value.get_shape()), squeezed_shape);
- return std::make_shared<ngraph::op::Broadcast>(broadcasted_node, output_shape, broadcast_axes);
+ return std::make_shared<ngraph::op::Broadcast>(broadcasted_value, output_shape, broadcast_axes);
}
namespace ngraph
{
namespace op
{
+ OutputVector numpy_style_broadcast_values(const OutputVector& values)
+ {
+ if (values.size() <= 1)
+ {
+ return values;
+ }
+
+ // find the output tensor's shape, then broadcast all inputs so that they are compatible
+ auto bcast_shapes = get_numpy_broadcast_shapes(values);
+
+ OutputVector broadcasted_inputs;
+ for (std::size_t i = 0; i < values.size(); ++i)
+ {
+ broadcasted_inputs.push_back(broadcast_node_numpy_style(
+ values[i], bcast_shapes.first, bcast_shapes.second[i]));
+ }
+ return broadcasted_inputs;
+ }
NodeVector numpy_style_broadcast(const NodeVector& inputs)
{
if (inputs.size() <= 1
@@ -176,19 +208,17 @@ namespace ngraph
NodeVector broadcasted_inputs;
for (std::size_t i = 0; i < inputs.size(); ++i)
{
- const std::shared_ptr<ngraph::Node> input_node = inputs[i];
- broadcasted_inputs.push_back(broadcast_node_numpy_style(
- input_node, bcast_shapes.first, bcast_shapes.second[i]));
+ broadcasted_inputs.push_back(broadcast_node_numpy_style(
+ inputs[i], bcast_shapes.first, bcast_shapes.second[i]));
}
return broadcasted_inputs;
}
- std::shared_ptr<ngraph::Node>
- numpy_style_broadcast(const std::shared_ptr<ngraph::Node>& input_node,
- const Shape& shape)
+ std::shared_ptr<ngraph::Node> numpy_style_broadcast(const Output<ngraph::Node>& value,
+ const Shape& shape)
{
- auto bcast_shape = get_numpy_broadcast_shapes({input_node->get_shape(), shape});
- return broadcast_node_numpy_style(input_node, bcast_shape.first, bcast_shape.second[0]);
+ auto bcast_shape = get_numpy_broadcast_shapes({value.get_shape(), shape});
+ return broadcast_node_numpy_style(value, bcast_shape.first, bcast_shape.second[0]);
}
NodeVector
@@ -227,6 +257,42 @@ namespace ngraph
broadcast_node_numpy_style(right, right_output_shape, right_full_shape)};
}
+ OutputVector
+ numpy_style_broadcast_values_for_matmul_operation(const Output<ngraph::Node>& left,
+ const Output<ngraph::Node>& right)
+ {
+ const auto& left_shape = left.get_shape();
+ const auto& right_shape = right.get_shape();
+ // Broadcast only _stack of matrices_ axes.
+ const auto& numpy_shapes = get_numpy_broadcast_shapes(
+ {Shape{std::begin(left_shape), std::next(std::end(left_shape), -2)},
+ Shape{std::begin(right_shape), std::next(std::end(right_shape), -2)}});
+ // Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
+ auto left_output_shape = numpy_shapes.first;
+ auto right_output_shape = numpy_shapes.first;
+ // Append the last two axes original dimensions.
+ left_output_shape.insert(std::end(left_output_shape),
+ std::next(std::begin(left_shape), left_shape.size() - 2),
+ std::end(left_shape));
+ right_output_shape.insert(std::end(right_output_shape),
+ std::next(std::begin(right_shape), right_shape.size() - 2),
+ std::end(right_shape));
+ auto left_full_shape = numpy_shapes.second.at(0);
+ auto right_full_shape = numpy_shapes.second.at(1);
+ // Append the last two axes original dimensions.
+ left_full_shape.insert(std::end(left_full_shape),
+ std::next(std::begin(left_shape), left_shape.size() - 2),
+ std::end(left_shape));
+ right_full_shape.insert(std::end(right_full_shape),
+ std::next(std::begin(right_shape), right_shape.size() - 2),
+ std::end(right_shape));
+ return {broadcast_node_numpy_style(left, left_output_shape, left_full_shape),
+ broadcast_node_numpy_style(right, right_output_shape, right_full_shape)};
+ }
NodeVector
legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right,
@@ -288,6 +354,67 @@ namespace ngraph
return {left, broadcast_right};
}
+ OutputVector
+ legacy_style_broadcast_values_for_binary_operation(const Output<ngraph::Node>& left,
+ const Output<ngraph::Node>& right,
+ size_t start_match_axis)
+ {
+ const auto& left_shape = left.get_shape();
+ const auto& right_shape = right.get_shape();
+ bool dimensions_identical = (left_shape == right_shape);
+ if (dimensions_identical)
+ {
+ return {left, right};
+ }
+ // Prepare new shape of right operand for broadcasting
+ // Remove dimensions with length=1 from back
+ auto new_right_shape = right_shape;
+ for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension)
+ {
+ if (new_right_shape[dimension] == 1)
+ {
+ new_right_shape.pop_back();
+ }
+ else
+ {
+ break;
+ }
+ }
+ // Find first dimensions at front with length different from 1
+ std::size_t num_ones = 0;
+ for (std::size_t dimension : new_right_shape)
+ {
+ if (dimension == 1)
+ {
+ ++num_ones;
+ }
+ else
+ {
+ break;
+ }
+ }
+ // Remove dimensions with length=1 from front
+ new_right_shape.erase(std::begin(new_right_shape),
+ std::next(std::begin(new_right_shape), num_ones));
+ auto reshape_right = std::make_shared<ngraph::op::Reshape>(
+ right, ngraph::get_default_order(right_shape), new_right_shape);
+ // Move broadcast start axis parameter to right
+ start_match_axis += num_ones;
+ auto broadcast_right = std::make_shared<ngraph::op::Broadcast>(
+ reshape_right,
+ left_shape,
+ calculate_broadcast_axes(left_shape, new_right_shape, start_match_axis));
+ return {left, broadcast_right};
+ }
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis)
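A sketch of what the new `legacy_style_broadcast_values_for_binary_operation` does with an explicit `start_match_axis` (hypothetical shapes; this is the pre-opset-7 ONNX rule described in the header below):

```cpp
// left: {2, 3, 4, 5}, right: {3, 4}, start_match_axis = 1.
// right is reshaped (dropping length-1 edge dimensions) and broadcast so its
// axes line up with left's axes 1..2, yielding a right-hand value of {2, 3, 4, 5}.
auto out = ngraph::op::legacy_style_broadcast_values_for_binary_operation(left, right, 1);
// out.at(0) is left, untouched; out.at(1) is the broadcast right.
```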
@@ -32,19 +32,45 @@ namespace ngraph
/// \param inputs Original list of inputs
///
/// \return Numpy-style broadcasted list of nodes.
- NodeVector numpy_style_broadcast(const NodeVector& inputs);
+ NodeVector numpy_style_broadcast(const NodeVector& inputs)
+ NGRAPH_DEPRECATED("Replace with numpy_style_value_broadcast");
- /// \brief Cast shape of an input node to the requested output shape using NumPy's broadcasting rules
+ /// \brief Cast shape of all input nodes for an element-wise operation that requires shape-compatibility
+ ///
+ /// \param values Original list of inputs
+ ///
+ /// \return Numpy-style broadcasted list of nodes.
+ OutputVector numpy_style_broadcast_values(const OutputVector& values);
+
+ /// \brief Cast shape of an output to the requested output shape using NumPy's broadcasting rules
///
- /// \param input_node original input node
+ /// \param value original value
/// \param shape requested output shape
///
- /// \return Broadcast node.
- std::shared_ptr<ngraph::Node>
- numpy_style_broadcast(const std::shared_ptr<ngraph::Node>& input_node,
- const Shape& shape);
+ /// \return Broadcast output.
+ std::shared_ptr<Node> numpy_style_broadcast(const Output<Node>& value, const Shape& shape);
+ /// \brief Cast shape of two outputs to make them compatible for an element-wise binary operation.
+ ///
+ /// If necessary the right-hand-side argument will be broadcast to match the shape
+ /// of left-hand-side argument. The starting of the mutually equal shape is
+ /// specified by the argument "start_match_axis", and if it is not set,
+ /// suffix matching is assumed.
+ ///
+ /// This style of broadcast was used in ONNX Op sets prior to version 7, where it was
+ /// replaced by numpy-style broadcasting.
+ ///
+ /// \param left Node which contain input of binary op.
+ /// \param right Node which contain input of binary op.
+ /// \param start_match_axis position in shape denoting start of the mutually equal shape
+ ///
+ /// \return Left and right node after broadcasting.
+ NodeVector legacy_style_broadcast_for_binary_operation(const std::shared_ptr<Node>& left,
+ const std::shared_ptr<Node>& right,
+ size_t start_match_axis)
+ NGRAPH_DEPRECATED("Replace with legacy_style_value_broadcast_for_binary_operation");
- /// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
+ /// \brief Cast shape of two outputs to make them compatible for an element-wise binary operation.
///
/// If necessary the right-hand-side argument will be broadcast to match the shape
/// of left-hand-side argument. The starting of the mutually equal shape is
@@ -59,10 +85,9 @@ namespace ngraph
/// \param start_match_axis position in shape denoting start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
- NodeVector
- legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
- const std::shared_ptr<ngraph::Node>& right,
- std::size_t start_match_axis);
+ OutputVector legacy_style_broadcast_values_for_binary_operation(const Output<Node>& left,
+ const Output<Node>& right,
+ size_t start_match_axis);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix multiplication.
///
@@ -76,9 +101,24 @@ namespace ngraph
///
/// \return The vector containing both nodes broadcasted.
///
- NodeVector
- numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<ngraph::Node>& left,
- const std::shared_ptr<ngraph::Node>& right);
+ NodeVector numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<Node>& left,
+ const std::shared_ptr<Node>& right)
+ NGRAPH_DEPRECATED("Replace with numpy_style_broadcast_value_for_matmul_operation.");
+
+ /// \brief Broadcast shape of two nodes to make them compatible for a matrix multiplication.
+ ///
+ /// \note This function is reflecting broadcasting behaviour of NumPy's `matmul` operation
+ ///       (https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html)
+ ///       This means that only \"stack of matrices\" axes are bidirectionally broadcasted.
+ ///       The last two dimensions are left untouched.
+ ///
+ /// \param[in] left The Node providing data for the left-hand side of matrix multiplication.
+ /// \param[in] right The Node providing data for the right-hand side of matrix multiplication.
+ ///
+ /// \return The vector containing both outputs broadcasted.
+ ///
+ OutputVector numpy_style_broadcast_values_for_matmul_operation(const Output<Node>& left,
+ const Output<Node>& right);
/// \brief Generate a list of broadcast axes.
///
@@ -118,22 +158,21 @@ namespace ngraph
output_shape, input_shape, output_shape.size() - input_shape.size());
}
- inline std::shared_ptr<ngraph::Node>
- make_broadcast_node(const std::shared_ptr<ngraph::Node>& node, ngraph::Shape new_shape)
+ inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& output,
+ Shape new_shape)
{
- return std::make_shared<ngraph::op::Broadcast>(
- node, new_shape, calculate_broadcast_axes(new_shape, node->get_shape()));
+ return std::make_shared<op::Broadcast>(
+ output, new_shape, calculate_broadcast_axes(new_shape, output.get_shape()));
}
- inline std::shared_ptr<ngraph::Node>
- make_broadcast_node(const std::shared_ptr<ngraph::Node>& node,
- const ngraph::Shape& new_shape,
- std::size_t start_match_axis)
+ inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& value,
+ const Shape& new_shape,
+ std::size_t start_match_axis)
{
- return std::make_shared<ngraph::op::Broadcast>(
- node,
+ return std::make_shared<op::Broadcast>(
+ value,
new_shape,
- calculate_broadcast_axes(new_shape, node->get_shape(), start_match_axis));
+ calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis));
}
} // namespace op
} // namespace ngraph
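The two inline `make_broadcast_node` overloads in use (hypothetical shapes): the first infers suffix matching, the second takes an explicit start axis:

```cpp
auto v = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{3});
// Suffix match: {3} aligns with the last axis of {2, 3}; broadcast axis is 0.
auto b1 = ngraph::op::make_broadcast_node(v, ngraph::Shape{2, 3});
// Explicit start axis 0: {3} aligns with axis 0 of {3, 4}; broadcast axis is 1.
auto b2 = ngraph::op::make_broadcast_node(v, ngraph::Shape{3, 4}, 0);
```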
@@ -69,32 +69,29 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size
return afunc;
}
- shared_ptr<Node> op::util::RNNCellBase::add(const shared_ptr<Node>& lhs,
- const shared_ptr<Node>& rhs)
+ shared_ptr<Node> op::util::RNNCellBase::add(const Output<Node>& lhs, const Output<Node>& rhs)
{
- auto args = op::numpy_style_broadcast({lhs, rhs});
+ auto args = op::numpy_style_broadcast_values({lhs, rhs});
return {make_shared<op::Add>(args.at(0), args.at(1))};
}
- shared_ptr<Node> op::util::RNNCellBase::sub(const shared_ptr<Node>& lhs,
- const shared_ptr<Node>& rhs)
+ shared_ptr<Node> op::util::RNNCellBase::sub(const Output<Node>& lhs, const Output<Node>& rhs)
{
- auto args = op::numpy_style_broadcast({lhs, rhs});
+ auto args = op::numpy_style_broadcast_values({lhs, rhs});
return {make_shared<op::Subtract>(args.at(0), args.at(1))};
}
- shared_ptr<Node> op::util::RNNCellBase::mul(const shared_ptr<Node>& lhs,
- const shared_ptr<Node>& rhs)
+ shared_ptr<Node> op::util::RNNCellBase::mul(const Output<Node>& lhs, const Output<Node>& rhs)
{
- auto args = op::numpy_style_broadcast({lhs, rhs});
+ auto args = op::numpy_style_broadcast_values({lhs, rhs});
return {make_shared<op::Multiply>(args.at(0), args.at(1))};
}
- shared_ptr<Node> op::util::RNNCellBase::clip(const shared_ptr<Node>& data) const
+ shared_ptr<Node> op::util::RNNCellBase::clip(const Output<Node>& data) const
{
if (m_clip == 0.f)
{
- return data;
+ return data.as_single_output_node();
}
return make_shared<op::Clamp>(data, -m_clip, m_clip);
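A note on the `clip` change above: once a helper accepts `Output<Node>` but still returns `std::shared_ptr<Node>`, the pass-through branch needs an explicit conversion. A sketch of the idiom, as read from this diff:

```cpp
std::shared_ptr<ngraph::Node> pass_through(const ngraph::Output<ngraph::Node>& data)
{
    // as_single_output_node() materializes a node with a single output carrying
    // this value; clip() uses the same call when no clamping is needed.
    return data.as_single_output_node();
}
```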
@@ -81,8 +81,7 @@ namespace ngraph
///
/// \return Node with element-wise add operation.
///
- static std::shared_ptr<Node> add(const std::shared_ptr<Node>& lhs,
- const std::shared_ptr<Node>& rhs);
+ static std::shared_ptr<Node> add(const Output<Node>& lhs, const Output<Node>& rhs);
///
/// \brief Creates node with element-wise subtract operation with numpy broadcasting.
///
@@ -91,8 +90,7 @@ namespace ngraph
///
/// \return Node with element-wise subtract operation.
///
- static std::shared_ptr<Node> sub(const std::shared_ptr<Node>& lhs,
- const std::shared_ptr<Node>& rhs);
+ static std::shared_ptr<Node> sub(const Output<Node>& lhs, const Output<Node>& rhs);
///
/// \brief Creates node with element-wise multiply operation with numpy broadcasting.
///
@@ -101,8 +99,7 @@ namespace ngraph
///
/// \return Node with element-wise multiply operation.
///
- static std::shared_ptr<Node> mul(const std::shared_ptr<Node>& lhs,
- const std::shared_ptr<Node>& rhs);
+ static std::shared_ptr<Node> mul(const Output<Node>& lhs, const Output<Node>& rhs);
///
/// \brief Creates node with element-wise clip operation with numpy broadcasting.
///
@@ -110,7 +107,7 @@ namespace ngraph
///
/// \return Node with element-wise clip operation.
///
- std::shared_ptr<Node> clip(const std::shared_ptr<Node>& data) const;
+ std::shared_ptr<Node> clip(const Output<Node>& data) const;
private:
const std::size_t m_hidden_size;