Commit caa11583 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Scott Cyphers

[ONNX] Use v1::Broadcast in ONNX LpNorm (#4083)

* Use v1::Broadcast in ONNX LpNorm

* Missing include

* Missing include

* Include the default opset in lp_norm.cpp

* Reference Constant from default_opset

* Some extra comments
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 07fcc56b
......@@ -14,16 +14,18 @@
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <numeric>
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "lp_norm.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "utils/common.hpp"
namespace ngraph
......@@ -37,22 +39,36 @@ namespace ngraph
NodeVector lp_norm(const Node& node)
{
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
std::size_t valid_axis =
const std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
const std::size_t valid_axis =
common::validate_axis(node, axis, data->get_shape().size());
ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2)
<< "Invalid `p` attribute value: " << p_norm
<< "Only normalization of 1st or 2nd order is supported.";
const AxisSet reduction_axes{valid_axis};
std::shared_ptr<ngraph::Node> norm = ngraph::builder::lp_norm(
data, reduction_axes, static_cast<std::size_t>(p_norm));
norm = std::make_shared<ngraph::opset0::Broadcast>(
norm, data->get_shape(), reduction_axes);
data, AxisSet{valid_axis}, static_cast<std::size_t>(p_norm));
const auto target_shape = default_opset::Constant::create(
element::i64, Shape{data->get_shape().size()}, data->get_shape());
// Create a default axes order matching the data tensor rank and erase the
// element at the 'valid_axis' position. The erased element indicates the axis
// along which the data should be broadcasted.
std::vector<size_t> axes_values(data->get_shape().size());
std::iota(axes_values.begin(), axes_values.end(), 0);
axes_values.erase(axes_values.begin() + valid_axis);
const auto axes_mapping = default_opset::Constant::create(
element::i64, Shape{axes_values.size()}, axes_values);
norm = std::make_shared<default_opset::Broadcast>(
norm, target_shape, axes_mapping);
return {data / norm};
return {std::make_shared<default_opset::Divide>(data, norm)};
}
} // namespace set_1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment