Commit 92c1cc19, authored by Adam Rogowiec and committed by Scott Cyphers

[ONNX] Generic N-dimensional MatMul operation. (#1921)

* [WIP] Draft for matmul

* Numpy-style broadcasting for matrix multiplication.

* 3D matrix multiplication with one big Dot/slice/concat.

* Generic ND matmul implementation using slice/dot/concat pattern.

* Code formatting.

* remove unused header

* Add missing header

* Utility reshape-like functions.

* Use utility functions.

* Review comments.

* Code format

* Use if/else instead of ternary operator for readability.

* Remove unused function overloading

* Utility function expanding tensor shape with empty axes.

* Use helper functions.

* Use type for auto variable initializer to fix Centos build

* Fix Centos build errors.
parent 61df6725
......@@ -85,6 +85,7 @@ add_library(onnx_import STATIC
op/log_softmax.hpp
op/lrn.cpp
op/lrn.hpp
op/matmul.cpp
op/matmul.hpp
op/max_pool.cpp
op/max_pool.hpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <iterator>
#include <memory>
#include <vector>
#include "ngraph/coordinate.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/shape.hpp"
#include "matmul.hpp"
#include "utils/broadcasting.hpp"
#include "utils/reshape.hpp"
/// \brief Extract a single matrix from a stack of matrices.
///
/// \param[in] node The input tensor. Assumed to be of rank 3 (this is not checked here).
/// \param[in] idx  The position on the first axis of the matrix to extract.
///
/// \return The node representing the extracted matrix, with the first axis squeezed away.
///
static std::shared_ptr<ngraph::Node> get_sub_matrix(const std::shared_ptr<ngraph::Node>& node,
                                                    std::size_t idx)
{
    const ngraph::Shape& input_shape{node->get_shape()};

    // The slice spans the full extent of every axis...
    ngraph::Coordinate slice_begin(input_shape.size());
    ngraph::Coordinate slice_end = input_shape;

    // ...except the first one, where exactly one element (at `idx`) is selected.
    slice_begin.at(0) = idx;
    slice_end.at(0) = idx + 1;

    std::shared_ptr<ngraph::Node> single_slice{
        std::make_shared<ngraph::op::Slice>(node, slice_begin, slice_end)};

    // Drop the first axis (squeeze removes axis 0 by default).
    return ngraph::onnx_import::reshape::squeeze(single_slice);
}
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Convert an ONNX MatMul node into an nGraph sub-graph.
///
/// Inputs where at least one tensor has rank <= 2 are forwarded to a single
/// Dot operation. Otherwise both inputs are broadcasted on their _stack of
/// matrices_ axes, these axes are collapsed into one, every pair of matrices
/// is multiplied with its own Dot, and the partial results are concatenated
/// and reshaped back to the broadcasted stack shape.
///
/// \param[in] node The ONNX node; its first two inputs are the operands.
///
/// \return A single-element vector holding the node with the product.
///
NodeVector matmul(const Node& node)
{
    const NodeVector& inputs{node.get_ng_inputs()};
    std::shared_ptr<ngraph::Node> lhs{inputs.at(0)};
    std::shared_ptr<ngraph::Node> rhs{inputs.at(1)};

    const std::size_t lhs_rank{lhs->get_shape().size()};
    const std::size_t rhs_rank{rhs->get_shape().size()};

    // Simple case: at least one operand has rank lower than or equal to 2.
    // This is delegated directly to the nGraph Dot operator.
    if (lhs_rank <= 2 || rhs_rank <= 2)
    {
        return {std::make_shared<ngraph::op::Dot>(inputs.at(0), inputs.at(1))};
    }

    // General case: both operands have rank greater than or equal to 3.
    // Start with broadcasting the operands against each other.
    const NodeVector& broadcasted_operands =
        numpy_style_broadcast_for_matmul_operation(lhs, rhs);
    lhs = broadcasted_operands.at(0);
    rhs = broadcasted_operands.at(1);

    // These refer to the shapes of the broadcasted nodes, i.e. the shapes
    // before any collapsing done below.
    const auto& lhs_shape = lhs->get_shape();
    const auto& rhs_shape = rhs->get_shape();

    // Flatten all _stack of matrices_ axes (everything but the last two) into
    // a single leading axis, so each operand becomes a 3D stack of matrices.
    const bool needs_collapse = lhs_shape.size() > 3;
    if (needs_collapse)
    {
        lhs = reshape::collapse(lhs, 0, lhs_shape.size() - 3);
        rhs = reshape::collapse(rhs, 0, rhs_shape.size() - 3);
    }

    // One small dot product per matrix pair in the stack.
    const std::size_t matrix_count = lhs->get_shape().at(0);
    NodeVector partial_products(matrix_count);
    for (std::size_t m = 0; m < matrix_count; ++m)
    {
        const auto& lhs_matrix = get_sub_matrix(lhs, m);
        const auto& rhs_matrix = get_sub_matrix(rhs, m);
        auto product = std::make_shared<ngraph::op::Dot>(lhs_matrix, rhs_matrix);

        // Give every partial result a leading axis of length one, which is the
        // axis the results are concatenated on afterwards.
        partial_products.at(m) = reshape::add_empty_axes(product);
    }

    // Glue the partial results together along the leading axis.
    auto stacked = std::make_shared<ngraph::op::Concat>(partial_products, 0);
    if (!needs_collapse)
    {
        return {stacked};
    }

    // Restore the original (broadcasted) _stack of matrices_ axes in front of
    // the two matrix axes of the result.
    const Shape& stacked_shape{stacked->get_shape()};
    Shape expanded_shape(std::begin(lhs_shape),
                         std::next(std::begin(lhs_shape), lhs_shape.size() - 2));
    expanded_shape.insert(std::end(expanded_shape),
                          std::next(std::begin(stacked_shape)),
                          std::end(stacked_shape));

    return {std::make_shared<ngraph::op::Reshape>(
        stacked, reshape::get_default_axis_vector(stacked_shape.size()), expanded_shape)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -17,7 +17,6 @@
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/dot.hpp"
#include "core/node.hpp"
......@@ -29,12 +28,7 @@ namespace ngraph
{
namespace set_1
{
inline NodeVector matmul(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))};
}
NodeVector matmul(const Node& node);
} // namespace set_1
} //namespace op
......
......@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <iterator>
#include <numeric>
#include <vector>
......@@ -30,7 +31,7 @@
/// \param left_shape Shape of first input tensor.
/// \param right_shape Shape of the second input tensor.
/// \return Shape of the output tensor and full shape of input tensors.
static std::vector<ngraph::Shape> calculate_numpy_broadcast_shape(ngraph::Shape left_shape,
std::vector<ngraph::Shape> get_numpy_broadcast_shape(ngraph::Shape left_shape,
ngraph::Shape right_shape)
{
ngraph::Shape output_shape;
......@@ -54,53 +55,101 @@ static std::vector<ngraph::Shape> calculate_numpy_broadcast_shape(ngraph::Shape
return {output_shape, left_shape, right_shape};
}
/// \brief Broadcast the input node to the requested output shape.
///
/// \note The source shape does not have to be the actual shape of the input node. However
///       it should be a superset of it (containing it as a continuous subset). This implies
///       the number of axes of the input node may be expanded.
///
/// \param[in] node         The input Node to be broadcasted.
/// \param[in] output_shape The shape of the broadcasted result.
/// \param[in] source_shape The shape from which the input node is broadcasted. It is
///                         indexed for every axis of `output_shape`.
///
/// \return The broadcasted Node.
///
static std::shared_ptr<ngraph::Node> broadcast(const std::shared_ptr<ngraph::Node>& node,
                                               const ngraph::Shape& output_shape,
                                               const ngraph::Shape& source_shape)
{
    ngraph::AxisVector axes_to_broadcast;
    ngraph::Shape non_unit_dims;

    // Every axis of length 1 in the source shape becomes a broadcast axis of the
    // nGraph Broadcast operation. Such axes are dropped from the shape fed into it,
    // which avoids a broadcasting axis conflict.
    for (std::size_t axis = 0; axis < output_shape.size(); ++axis)
    {
        const auto dim = source_shape.at(axis);
        if (dim != 1)
        {
            non_unit_dims.push_back(dim);
            continue;
        }
        axes_to_broadcast.push_back(axis);
    }

    // Reshape the input so it no longer contains the axes of length 1.
    const auto input_order =
        ngraph::onnx_import::reshape::get_default_axis_vector(node->get_shape().size());
    auto squeezed_node =
        std::make_shared<ngraph::op::Reshape>(node, input_order, non_unit_dims);

    return std::make_shared<ngraph::op::Broadcast>(
        squeezed_node, output_shape, axes_to_broadcast);
}
namespace ngraph
{
namespace onnx_import
{
NodeVector numpy_style_broadcast_for_binary_operation(const std::shared_ptr<Node>& left,
const std::shared_ptr<Node>& right)
NodeVector
numpy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right)
{
auto left_shape = left->get_shape();
auto right_shape = right->get_shape();
auto numpy_shapes = calculate_numpy_broadcast_shape(left_shape, right_shape);
const auto& left_shape = left->get_shape();
const auto& right_shape = right->get_shape();
const auto& numpy_shapes = get_numpy_broadcast_shape(left_shape, right_shape);
auto output_shape = numpy_shapes.at(0);
auto left_full_shape = numpy_shapes.at(1);
auto right_full_shape = numpy_shapes.at(2);
AxisVector left_broadcast_axes;
AxisVector right_broadcast_axes;
Shape new_left_shape;
Shape new_right_shape;
// Positions of dims which have length of 1 are needed to calculate broadcast_axes for nGraph broadcast operation.
// We need to remove all ones from source shape (left_broadcast_axes) to avoid broadcasting axis conflict.
for (auto index = 0; index < output_shape.size(); ++index)
{
(left_full_shape.at(index) == 1)
? left_broadcast_axes.push_back(index)
: new_left_shape.push_back(left_full_shape.at(index));
(right_full_shape.at(index) == 1)
? right_broadcast_axes.push_back(index)
: new_right_shape.push_back(right_full_shape.at(index));
return {broadcast(left, output_shape, left_full_shape),
broadcast(right, output_shape, right_full_shape)};
}
// Remove dims which have length of 1 from source shape
std::shared_ptr<Node> broadcasted_left = std::make_shared<op::Reshape>(
left, reshape::get_default_axis_vector(left->get_shape().size()), new_left_shape);
// Remove dims which have length of 1 from source shape
std::shared_ptr<Node> broadcasted_right = std::make_shared<op::Reshape>(
right,
reshape::get_default_axis_vector(right->get_shape().size()),
new_right_shape);
broadcasted_left = std::make_shared<op::Broadcast>(
broadcasted_left, output_shape, left_broadcast_axes);
broadcasted_right = std::make_shared<op::Broadcast>(
broadcasted_right, output_shape, right_broadcast_axes);
/// \brief Broadcast two nodes so they can be multiplied as stacks of matrices.
///
/// Only the _stack of matrices_ axes (all axes except the last two) take part
/// in the NumPy-style broadcasting; the last two axes of each input keep their
/// original dimensions.
///
/// \note Both inputs must have at least two axes, since the last two axes are
///       cut off with `std::next(std::end(shape), -2)`.
///
/// \param[in] left  The node providing data for the left-hand side of the multiplication.
/// \param[in] right The node providing data for the right-hand side of the multiplication.
///
/// \return The vector containing both nodes broadcasted (left first, right second).
///
NodeVector
    numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<ngraph::Node>& left,
                                               const std::shared_ptr<ngraph::Node>& right)
{
    const auto& left_shape = left->get_shape();
    const auto& right_shape = right->get_shape();

    // Broadcast only _stack of matrices_ axes.
    const auto& numpy_shapes = get_numpy_broadcast_shape(
        Shape{std::begin(left_shape), std::next(std::end(left_shape), -2)},
        Shape{std::begin(right_shape), std::next(std::end(right_shape), -2)});

    // Appends the original dimensions of the last two axes of `origin` to `target`.
    const auto append_matrix_axes = [](Shape& target, const Shape& origin) {
        target.insert(std::end(target),
                      std::next(std::begin(origin), origin.size() - 2),
                      std::end(origin));
    };

    // Output shapes: the broadcasted _stack of matrices_ axes followed by the
    // untouched matrix axes of the respective input.
    auto left_output_shape = numpy_shapes.at(0);
    auto right_output_shape = numpy_shapes.at(0);
    append_matrix_axes(left_output_shape, left_shape);
    append_matrix_axes(right_output_shape, right_shape);

    // Full source shapes as returned for each input, again followed by the
    // untouched matrix axes.
    auto left_full_shape = numpy_shapes.at(1);
    auto right_full_shape = numpy_shapes.at(2);
    append_matrix_axes(left_full_shape, left_shape);
    append_matrix_axes(right_full_shape, right_shape);

    return {broadcast(left, left_output_shape, left_full_shape),
            broadcast(right, right_output_shape, right_full_shape)};
}
NodeVector
......@@ -108,8 +157,8 @@ namespace ngraph
const std::shared_ptr<ngraph::Node>& right,
std::size_t start_match_axis)
{
auto left_shape = left->get_shape();
auto right_shape = right->get_shape();
const auto& left_shape = left->get_shape();
const auto& right_shape = right->get_shape();
bool dimensions_identical = (left_shape == right_shape);
if (dimensions_identical)
......
......@@ -67,6 +67,22 @@ namespace ngraph
const std::shared_ptr<ngraph::Node>& right,
std::size_t start_match_axis);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix multiplication.
///
/// \note This function reflects the broadcasting behaviour of NumPy's `matmul` operation
/// \link https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html
/// This means that only the \"stack of matrices\" axes are bidirectionally broadcasted.
/// The last two dimensions are left untouched.
///
/// \param[in] left The Node providing data for the left-hand side of matrix multiplication.
/// \param[in] right The Node providing data for the right-hand side of matrix multiplication.
///
/// \return The vector containing both nodes broadcasted.
///
NodeVector
numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right);
/// \brief Generate a list of broadcast axes.
///
/// \details Informally, a broadcast "adds" axes to the input tensor, replicating
......@@ -123,5 +139,4 @@ namespace ngraph
calculate_broadcast_axes(new_shape, node->get_shape(), start_match_axis));
}
} // namespace onnx_import
} // namespace ngraph
......@@ -15,15 +15,11 @@
//*****************************************************************************
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "exceptions.hpp"
#include "utils/common.hpp"
......@@ -145,6 +141,71 @@ namespace ngraph
return reorder_axes(node, axes_order);
}
/// \brief Remove the requested axes from the shape of the input tensor.
///
/// \param[in] node The node to be squeezed.
/// \param[in] axes The indexes of the axes to be removed. When empty, the input
///                 node is returned unchanged.
///
/// \return The node reshaped to the input shape without the requested axes.
///
/// \throws std::out_of_range (from `at`) when an axis index exceeds the input rank.
///
std::shared_ptr<ngraph::Node> squeeze(const std::shared_ptr<ngraph::Node>& node,
                                      std::vector<std::size_t> axes)
{
    if (axes.empty())
    {
        return node;
    }

    // Mark every requested axis with a zero. The axis to drop is the VALUE stored
    // in `axes`, not the position of that value inside the `axes` vector.
    Shape in_shape{node->get_shape()};
    for (const std::size_t axis : axes)
    {
        in_shape.at(axis) = 0;
    }

    // Keep only the axes that were not marked. Note that zero is used as the
    // marker, so axes which already have length zero are dropped as well.
    Shape output_shape;
    for (const auto dim : in_shape)
    {
        if (dim != 0)
        {
            output_shape.push_back(dim);
        }
    }
    return reshape(node, output_shape);
}
/// \brief Collapse a continuous range of axes into a single axis.
///
/// The result shape consists of the axes preceding `start_axis` (none when
/// `start_axis` is 0), one axis whose length is the product of the lengths of
/// axes `start_axis`..`end_axis`, and the axes following `end_axis`.
///
/// \note The caller must ensure `start_axis <= end_axis < rank`; the range is not
///       validated here.
///
/// \param[in] node       The node to be reshaped.
/// \param[in] start_axis The index of the first collapsed axis.
/// \param[in] end_axis   The index of the last collapsed axis (inclusive).
///
/// \return The node with the specified axes collapsed.
///
std::shared_ptr<ngraph::Node> collapse(const std::shared_ptr<ngraph::Node>& node,
                                       const std::size_t start_axis,
                                       const std::size_t end_axis)
{
    const Shape& shape = node->get_shape();
    const auto collapse_begin = std::next(std::begin(shape), start_axis);
    const auto collapse_end = std::next(std::begin(shape), end_axis + 1);

    // The initial value fixes the accumulator type; use std::size_t explicitly so
    // the product is not truncated where `unsigned long` is narrower than size_t.
    const std::size_t collapsed_axis_size = std::accumulate(
        collapse_begin, collapse_end, std::size_t{1}, std::multiplies<std::size_t>());

    // Leading axes are preserved so the element count stays unchanged also for
    // start_axis > 0.
    Shape output_shape(std::begin(shape), collapse_begin);
    output_shape.push_back(collapsed_axis_size);
    output_shape.insert(std::end(output_shape), collapse_end, std::end(shape));
    return reshape(node, output_shape);
}
/// \brief Change shape of input tensor.
///
/// \param[in] node       The node which shape will be changed.
/// \param[in] axis_order The order in which the input axes are read by the
///                       Reshape operation.
/// \param[in] shape      The new shape for input tensor.
///
/// \return The node representing reshaped input tensor.
///
std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
                                      const AxisVector& axis_order,
                                      const Shape& shape)
{
    // Forward the caller-provided axis order instead of silently replacing it
    // with the default one.
    return std::make_shared<ngraph::op::Reshape>(node, axis_order, shape);
}
/// \brief Expand the node shape with axes of length one.
///
/// \param[in] node                 The node to be expanded.
/// \param[in] outermost_axes_count The number of axes added at the front of the shape.
/// \param[in] innermost_axes_count The number of axes added at the end of the shape.
///
/// \return The node reshaped to the expanded shape.
///
std::shared_ptr<ngraph::Node> add_empty_axes(const std::shared_ptr<ngraph::Node>& node,
                                             std::size_t outermost_axes_count,
                                             std::size_t innermost_axes_count)
{
    const Shape& input_shape = node->get_shape();

    // Leading ones, then the original dimensions, then trailing ones.
    Shape expanded_shape(outermost_axes_count, 1);
    expanded_shape.insert(
        std::end(expanded_shape), std::begin(input_shape), std::end(input_shape));
    expanded_shape.insert(std::end(expanded_shape), innermost_axes_count, 1);

    const auto input_order = reshape::get_default_axis_vector(input_shape.size());
    return std::make_shared<ngraph::op::Reshape>(node, input_order, expanded_shape);
}
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph
......@@ -16,8 +16,15 @@
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
......@@ -79,6 +86,61 @@ namespace ngraph
/// \return: New node with reversed dimensions.
std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node);
/// \brief Remove empty axes from input tensor.
///
/// \note NOTE(review): "empty" presumably means an axis of length one; the length
///       of the removed axes is not checked by this declaration - confirm.
///
/// \param[in] node The node to be squeezed.
/// \param[in] axes The vector defining indexes of axes to be removed.
///                 Defaults to `{0}`, i.e. the outermost axis.
///
/// \return The squeezed node.
///
std::shared_ptr<ngraph::Node> squeeze(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes = {0});
/// \brief Collapse specified axes into a single one.
///
/// \note Collapsed axes create a continuous range starting from outermost axis.
///
/// \param[in] node The node to be reshaped.
/// \param[in] start_axis The start axis index.
/// \param[in] end_axis The end axis (inclusive) index. NOTE(review): presumably
///                     must satisfy `start_axis <= end_axis < rank` - confirm.
///
/// \return The node with collapsed specified axes.
///
std::shared_ptr<ngraph::Node> collapse(const std::shared_ptr<ngraph::Node>& node,
const std::size_t start_axis,
const std::size_t end_axis);
/// \brief Change shape of input tensor.
///
/// \param[in] node The node which shape will be changed.
/// \param[in] axis_order The order of the input axes for the reshape.
///                       NOTE(review): confirm the definition forwards this
///                       argument rather than substituting the default order.
/// \param[in] shape The new shape for input tensor.
///
/// \return The node representing reshaped input tensor.
///
std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
const AxisVector& axis_order,
const Shape& shape);
/// \brief Change shape of input tensor using the default axis order.
///
/// \param[in] node  The node which shape will be changed.
/// \param[in] shape The new shape for input tensor.
///
/// \return The node representing reshaped input tensor.
///
inline std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
                                             const Shape& shape)
{
    // The default axis vector is generated for the rank of the input node.
    const auto input_rank = node->get_shape().size();
    return reshape(node, get_default_axis_vector(input_rank), shape);
}
/// \brief Expands node tensor shape with empty axes.
///
/// \param[in] node The node to be expanded.
/// \param[in] outermost_axes_count The number of added outermost axes.
/// At the front of the shape. Defaults to 1.
/// \param[in] innermost_axes_count The number of added innermost axes.
/// At the end of the shape. Defaults to 0.
///
/// \return The node with added empty axes.
///
std::shared_ptr<ngraph::Node> add_empty_axes(const std::shared_ptr<ngraph::Node>& node,
std::size_t outermost_axes_count = 1,
std::size_t innermost_axes_count = 0);
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment