Commit 16ac55e3 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Michał Karzyński

[ONNX] LSTM node (#1945)

parent a3cab07b
......@@ -88,6 +88,8 @@ add_library(onnx_import STATIC
op/log_softmax.hpp
op/lrn.cpp
op/lrn.hpp
op/lstm.cpp
op/lstm.hpp
op/matmul.cpp
op/matmul.hpp
op/max_pool.cpp
......
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector lstm(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -14,9 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/slice.hpp"
#include "op/split.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
......@@ -82,37 +81,6 @@ namespace ngraph
{
namespace set_1
{
namespace detail
{
template <typename T>
inline T get_valid_array_index(T left, T right)
{
return (left >= 0) ? std::min(left, right)
: std::max(static_cast<T>(0), right + left);
}
inline std::shared_ptr<ngraph::op::Slice>
make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes,
std::vector<std::size_t> starts,
std::vector<std::size_t> ends)
{
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
{
std::size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
}
return std::make_shared<ngraph::op::Slice>(
node, lower_bounds, upper_bounds);
}
} // namespace detail
NodeVector split(const Node& node)
{
std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
......@@ -143,16 +111,7 @@ namespace ngraph
length_parts.assign(count_outputs, length_axis_to_split / count_outputs);
}
std::size_t start_index{0};
NodeVector outputs;
for (const auto& length_part : length_parts)
{
std::size_t end_index{start_index + length_part};
outputs.push_back(detail::make_ng_slice(
input, {axis_to_split}, {start_index}, {end_index}));
start_index = end_index;
}
return outputs;
return reshape::split(input, length_parts, axis_to_split);
}
} // namespace set_1
......
......@@ -56,6 +56,7 @@
#include "op/log.hpp"
#include "op/log_softmax.hpp"
#include "op/lrn.hpp"
#include "op/lstm.hpp"
#include "op/matmul.hpp"
#include "op/max.hpp"
#include "op/max_pool.hpp"
......@@ -183,6 +184,7 @@ namespace ngraph
REGISTER_OPERATOR("Log", 1, log);
REGISTER_OPERATOR("LogSoftmax", 1, log_softmax);
REGISTER_OPERATOR("LRN", 1, lrn);
REGISTER_OPERATOR("LSTM", 1, lstm);
REGISTER_OPERATOR("MatMul", 1, matmul);
REGISTER_OPERATOR("MaxPool", 1, max_pool);
REGISTER_OPERATOR("Max", 1, max);
......
......@@ -19,9 +19,15 @@
#include <cmath> // std::floor
#include <cstddef> // std::size_t
#include <iterator> // std::begin, std::end
#include <memory> // std::shared_ptr, std::make_shared
#include <type_traits> // std::enable_if, std::is_floating_point, std::is_integral
#include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/shape.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
......@@ -100,6 +106,36 @@ namespace ngraph
return range;
}
/// \brief Makes a Constant Ngraph node.
///
/// \param[in] type The node element type.
/// \param[in] shape The tensor data shape.
/// \param[in] data The data to initialize node with.
///
/// \tparam T Input data value type.
///
/// \return The Ngraph node representing Constant data.
///
template <typename T>
std::shared_ptr<ngraph::Node> make_constant_node(const ngraph::element::Type& type,
const ngraph::Shape& shape,
const std::vector<T>& data)
{
std::shared_ptr<ngraph::Node> node;
// Make constant node filled with single value.
if (data.size() == 1)
{
node = std::make_shared<ngraph::op::Constant>(type, ngraph::Shape{}, data);
node = make_broadcast_node(node, shape);
}
else
{
node = std::make_shared<ngraph::op::Constant>(type, shape, data);
}
return node;
}
} // namespace common
} // namespace onnx_import
} // namespace ngraph
......@@ -15,11 +15,15 @@
//*****************************************************************************
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric>
#include <vector>
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "exceptions.hpp"
#include "utils/common.hpp"
......@@ -31,6 +35,33 @@ namespace ngraph
{
namespace reshape
{
namespace
{
inline std::size_t get_valid_array_index(std::size_t idx, std::size_t axis_size)
{
return std::min(idx, axis_size);
}
std::shared_ptr<op::Slice> make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& axes,
const std::vector<std::size_t>& starts,
const std::vector<std::size_t>& ends)
{
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
{
std::size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
}
return std::make_shared<op::Slice>(node, lower_bounds, upper_bounds);
}
} // namespace anonymous
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis)
{
......@@ -206,6 +237,36 @@ namespace ngraph
node, reshape::get_default_axis_vector(node->get_shape().size()), output_shape);
}
NodeVector split(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& length_parts,
std::size_t axis)
{
std::size_t start_index{0};
NodeVector outputs;
for (const auto& length_part : length_parts)
{
std::size_t end_index{start_index + length_part};
outputs.push_back(make_ng_slice(node, {axis}, {start_index}, {end_index}));
start_index = end_index;
}
return outputs;
}
NodeVector
split(const std::shared_ptr<ngraph::Node>& node, std::size_t split_parts, int axis)
{
std::size_t axis_to_split{static_cast<std::size_t>(axis)};
if (axis < 0)
{
axis_to_split = node->get_shape().size() + axis;
}
std::size_t length_axis_to_split{node->get_shape().at(axis_to_split)};
std::vector<std::size_t> length_parts(split_parts,
length_axis_to_split / split_parts);
return split(node, length_parts, axis_to_split);
}
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph
......@@ -141,6 +141,35 @@ namespace ngraph
std::size_t outermost_axes_count = 1,
std::size_t innermost_axes_count = 0);
/// \brief Split node on specified axis into multiple parts.
///
/// \param[in] node The input node.
/// \param[in] length_parts The vector defining the lengts of each splitted part.
/// \param[in] axis The axis we split input node on. Default value is zero axis.
///
/// \return The vector containing multiple nodes we split input node into.
///
NodeVector split(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& length_parts,
std::size_t axis = 0);
/// \brief Split node on specified axis into multiple parts.
///
/// \param[in] node The input node.
/// \param[in] split_parts The number of parts we want to split input node at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param[in] axis The axis we split input node on. Default value is zero axis.
///
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing).
///
/// \return The vector containing multiple nodes we split input node into.
///
NodeVector split(const std::shared_ptr<ngraph::Node>& node,
std::size_t split_parts,
int axis = 0);
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment