Commit d46330de authored by tsocha's avatar tsocha Committed by Scott Cyphers

[ONNX] Fix default axis for legacy broadcasting (#2614)

* [ONNX] Fix default axis for legacy broadcasting

* Add new function to get default axis for legacy broadcasting

* Add std:: prefix

* Remove unnecesary default_axis function

* Style check
parent 369e4a22
......@@ -31,7 +31,10 @@ namespace ngraph
{
inline NodeVector add(const Node& node)
{
auto axis = node.get_attribute_value<int64_t>("axis", 0);
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
......
......@@ -31,7 +31,10 @@ namespace ngraph
{
inline NodeVector div(const Node& node)
{
auto axis = node.get_attribute_value<int64_t>("axis", 0);
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
......
......@@ -32,7 +32,10 @@ namespace ngraph
{
inline NodeVector mul(const Node& node)
{
auto axis = node.get_attribute_value<int64_t>("axis", 0);
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
......
......@@ -31,7 +31,10 @@ namespace ngraph
{
inline NodeVector sub(const Node& node)
{
auto axis = node.get_attribute_value<int64_t>("axis", 0);
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment