Commit d3da8b36 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[ONNX] Change Slice and Where to produce v1 ops (#4140)

* Changed onnx slice to produce v1

* onnx where produces v1 ops
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 09b7e413
......@@ -18,9 +18,8 @@
#include <memory>
#include <vector>
#include "default_opset.hpp"
#include "ngraph/node.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "slice.hpp"
#include "utils/common.hpp"
namespace
......@@ -43,6 +42,7 @@ namespace ngraph
{
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
Shape data_shape = data->get_shape();
const auto data_rank = data_shape.size();
auto starts = node.get_attribute_value<std::vector<int64_t>>("starts");
auto ends = node.get_attribute_value<std::vector<int64_t>>("ends");
......@@ -50,7 +50,7 @@ namespace ngraph
auto axes = node.get_attribute_value<std::vector<int64_t>>(
"axes", common::get_monotonic_range<int64_t>(data_shape.size()));
Shape lower_bounds(data_shape.size());
Shape lower_bounds(data_rank);
Shape upper_bounds = data_shape;
for (size_t idx = 0; idx < axes.size(); ++idx)
......@@ -72,8 +72,20 @@ namespace ngraph
}
}
return {
std::make_shared<ngraph::opset0::Slice>(data, lower_bounds, upper_bounds)};
const auto begin = ngraph::op::Constant::create(
element::i64, Shape{lower_bounds.size()}, lower_bounds);
const auto end = ngraph::op::Constant::create(
element::i64, Shape{upper_bounds.size()}, upper_bounds);
const auto strides = ngraph::op::Constant::create(
element::i64, Shape{data_rank}, std::vector<int64_t>(data_rank, 1));
return {std::make_shared<default_opset::StridedSlice>(
data,
begin,
end,
strides,
std::vector<int64_t>(data_rank, 0),
std::vector<int64_t>(data_rank, 0))};
}
} // namespace set_1
......
......@@ -19,9 +19,8 @@
#include <memory>
#include "core/node.hpp"
#include "default_opset.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/opsets/opset0.hpp"
namespace ngraph
{
......@@ -33,9 +32,9 @@ namespace ngraph
{
inline NodeVector where(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::opset0::Select>(
return {std::make_shared<default_opset::Select>(
ng_inputs.at(0), ng_inputs.at(1), ng_inputs.at(2))};
}
} // namespace set_1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment