Commit 9d1f2367 authored by tsocha's avatar tsocha Committed by Artur Wojcik

[ONNX] Slice op (#1610)

* [ONNX] Slice op

* Review fix pt. 1

* Review fix pt. 2
parent 302ab491
...@@ -88,11 +88,13 @@ add_library(onnx_import STATIC ...@@ -88,11 +88,13 @@ add_library(onnx_import STATIC
op/relu.hpp op/relu.hpp
op/reshape.cpp op/reshape.cpp
op/reshape.hpp op/reshape.hpp
op/shape.cpp
op/shape.hpp
op/selu.cpp op/selu.cpp
op/selu.hpp op/selu.hpp
op/shape.hpp
op/shape.cpp
op/sigmoid.hpp op/sigmoid.hpp
op/slice.cpp
op/slice.hpp
op/softmax.cpp op/softmax.cpp
op/softmax.hpp op/softmax.hpp
op/softplus.cpp op/softplus.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/slice.hpp"
#include "slice.hpp"
#include "utils/common.hpp"
static inline int64_t get_valid_array_idx(int64_t idx, int64_t last_idx)
{
return (idx >= 0) ? std::min(idx, last_idx) : std::max<int64_t>(0, last_idx + idx);
}
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector slice(const Node& node)
{
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
Shape data_shape = data->get_shape();
auto starts = node.get_attribute_value<std::vector<int64_t>>("starts");
auto ends = node.get_attribute_value<std::vector<int64_t>>("ends");
auto axes = node.get_attribute_value<std::vector<int64_t>>(
"axes", common::get_monotonic_range<int64_t>(data_shape.size()));
Shape lower_bounds(data_shape.size());
Shape upper_bounds = data_shape;
for (auto idx = 0; idx < axes.size(); ++idx)
{
size_t axis = axes.at(idx);
lower_bounds.at(axis) =
get_valid_array_idx(starts.at(idx), data_shape.at(axis));
upper_bounds.at(axis) = get_valid_array_idx(ends.at(idx), data_shape.at(axis));
}
return {std::make_shared<ngraph::op::Slice>(data, lower_bounds, upper_bounds)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector slice(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -61,6 +61,7 @@ ...@@ -61,6 +61,7 @@
#include "op/selu.hpp" #include "op/selu.hpp"
#include "op/shape.hpp" #include "op/shape.hpp"
#include "op/sigmoid.hpp" #include "op/sigmoid.hpp"
#include "op/slice.hpp"
#include "op/softmax.hpp" #include "op/softmax.hpp"
#include "op/softplus.hpp" #include "op/softplus.hpp"
#include "op/softsign.hpp" #include "op/softsign.hpp"
...@@ -171,9 +172,10 @@ namespace ngraph ...@@ -171,9 +172,10 @@ namespace ngraph
std::bind(op::reduce_sum_square, std::placeholders::_1)); std::bind(op::reduce_sum_square, std::placeholders::_1));
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1)); m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Reshape", std::bind(op::reshape, std::placeholders::_1)); m_map.emplace("Reshape", std::bind(op::reshape, std::placeholders::_1));
m_map.emplace("Shape", std::bind(op::shape, std::placeholders::_1));
m_map.emplace("Selu", std::bind(op::selu, std::placeholders::_1)); m_map.emplace("Selu", std::bind(op::selu, std::placeholders::_1));
m_map.emplace("Shape", std::bind(op::shape, std::placeholders::_1));
m_map.emplace("Sigmoid", std::bind(op::sigmoid, std::placeholders::_1)); m_map.emplace("Sigmoid", std::bind(op::sigmoid, std::placeholders::_1));
m_map.emplace("Slice", std::bind(op::slice, std::placeholders::_1));
m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1)); m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1));
m_map.emplace("Softplus", std::bind(op::softplus, std::placeholders::_1)); m_map.emplace("Softplus", std::bind(op::softplus, std::placeholders::_1));
m_map.emplace("Softsign", std::bind(op::softsign, std::placeholders::_1)); m_map.emplace("Softsign", std::bind(op::softsign, std::placeholders::_1));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment