Commit 76b8b4d4 authored by Adam Rogowiec, committed by Scott Cyphers

[ONNX] Fix MatMul op for vec @ tensor multiplication (#1969)

* Add static keyword for helper function.

* Fix MatMul for cases where the left-hand side is a 1D vector.

- Add unit-test for this case.

* Add new line at the end of file.

* Log warning when dealing with scalars.

* Apply clang-format.

* Review: fix spelling, rename test model.
parent 6e234d65
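Before the diff, a minimal standalone sketch (plain C++, deliberately not using any nGraph API) of the slice-dot-concat strategy this fix applies when the left-hand side is a 1D vector and the right-hand side is a rank-3 tensor. The shapes and values mirror the new unit test below (left shape {2}, right shape {3, 2, 1}); all names here are illustrative, not code from the commit.

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> left{0.f, 1.f};                      // shape {2}
    std::vector<float> right{0.f, 1.f, 2.f, 3.f, 4.f, 5.f}; // shape {3, 2, 1}, row-major

    // The vector operand is not broadcast, so the number of groups comes
    // from the first axis of the bigger (rank-3) tensor.
    const std::size_t groups = 3;
    const std::size_t k = 2; // shared (contraction) axis
    const std::size_t n = 1; // columns of each sliced sub-matrix

    std::vector<float> result; // shape {3, 1} after concatenating the small dots
    for (std::size_t g = 0; g < groups; ++g)
    {
        // Slice the g-th {2, 1} sub-matrix and dot the vector with it.
        for (std::size_t col = 0; col < n; ++col)
        {
            float acc = 0.f;
            for (std::size_t i = 0; i < k; ++i)
            {
                acc += left[i] * right[g * k * n + i * n + col];
            }
            result.push_back(acc);
        }
    }

    for (float v : result)
    {
        std::cout << v << ' '; // prints: 1 3 5
    }
    std::cout << '\n';
}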
@@ -20,19 +20,21 @@
 #include <vector>
 
 #include "ngraph/coordinate.hpp"
+#include "ngraph/log.hpp"
 #include "ngraph/op/concat.hpp"
 #include "ngraph/op/dot.hpp"
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/slice.hpp"
 #include "ngraph/shape.hpp"
 
+#include "exceptions.hpp"
 #include "matmul.hpp"
 #include "utils/broadcasting.hpp"
 #include "utils/reshape.hpp"
 
-/// \brief Slice the sub matrix from 3D input tensor.
+/// \brief Slice the sub matrix from the input tensor.
 ///
-/// \param[in] node The input tensor. Must be 3D.
+/// \param[in] node The input tensor. Must be at most of rank 3.
 /// \param[in] idx The index on the first axis, at which to slice sub-matrix.
 ///
 /// \return The node representing sub matrix.
@@ -41,6 +43,10 @@ static std::shared_ptr<ngraph::Node> get_sub_matrix(const std::shared_ptr<ngraph
                                                     std::size_t idx)
 {
     const ngraph::Shape& shape{node->get_shape()};
+    if (shape.size() < 3)
+    {
+        return node;
+    }
     // Below bounds defines the sub_matrix through ranges for each input node axis.
     ngraph::Coordinate lower_bounds(shape.size());
     ngraph::Coordinate upper_bounds = shape;
@@ -71,24 +77,35 @@ namespace ngraph
                 std::size_t left_rank{left->get_shape().size()};
                 std::size_t right_rank{right->get_shape().size()};
 
-                // First (easy) case:
-                // Multiply two tensors where one of them or both has rank lower equal 2.
-                // This is already internally handled by Ngraph Dot operator.
-                if (left_rank <= 2 || right_rank <= 2)
+                if (left_rank == 0 || right_rank == 0)
+                {
+                    NGRAPH_WARN
+                        << (node) << " "
+                        << "ONNX standard doesn't allow scalar operands, however nGraph "
+                           "accepts them. Consider use of element-wise multiplication instead "
+                           "to conform with ONNX standard.";
+                }
+
+                // First (easy) case that is already internally handled by Ngraph Dot operator.
+                // Multiply two tensors where both of them has rank lower equal 2.
+                if (left_rank <= 2 && right_rank <= 2)
                 {
                     return {
                         std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))};
                 }
 
                 // Second case:
-                // Multiply two tensors where each of them is rank greater equal 3.
+                // Multiply two tensors where at least one of them is rank greater equal 3.
 
-                // Broadcast input arguments.
-                const NodeVector& broadcasted_nodes =
-                    numpy_style_broadcast_for_matmul_operation(left, right);
-
-                left = broadcasted_nodes.at(0);
-                right = broadcasted_nodes.at(1);
+                // Broadcast input arguments only if both of them are not vectors.
+                if (left_rank > 1 && right_rank > 1)
+                {
+                    const NodeVector& broadcasted_nodes =
+                        numpy_style_broadcast_for_matmul_operation(left, right);
+
+                    left = broadcasted_nodes.at(0);
+                    right = broadcasted_nodes.at(1);
+                }
 
                 const auto& left_shape = left->get_shape();
                 const auto& right_shape = right->get_shape();
@@ -97,11 +114,20 @@
                 if (left_shape.size() > 3)
                 {
                     left = reshape::collapse(left, 0, left_shape.size() - 3);
+                }
+                if (right_shape.size() > 3)
+                {
                     right = reshape::collapse(right, 0, right_shape.size() - 3);
                 }
 
                 // Perform multiple small dot products
                 std::size_t groups = left->get_shape().at(0);
+                // If we haven't broadcast earlier this means that one of the inputs is a vector,
+                // thus the number of groups is defined by the shape of the bigger tensor.
+                if (right->get_shape().size() > left->get_shape().size())
+                {
+                    groups = right->get_shape().at(0);
+                }
 
                 NodeVector small_dots(groups);
                 for (std::size_t g = 0; g < groups; ++g)
@@ -119,7 +145,7 @@ namespace ngraph
                 // Concatenate sub_dots on groups axis.
                 auto result = std::make_shared<ngraph::op::Concat>(small_dots, 0);
 
-                if (left_shape.size() <= 3)
+                if (left_shape.size() <= 3 && right_shape.size() <= 3)
                 {
                     return {result};
                 }
...
@@ -31,8 +31,8 @@
 /// \param left_shape Shape of first input tensor.
 /// \param right_shape Shape of the second input tensor.
 /// \return Shape of the output tensor and full shape of input tensors.
-std::vector<ngraph::Shape> get_numpy_broadcast_shape(ngraph::Shape left_shape,
-                                                     ngraph::Shape right_shape)
+static std::vector<ngraph::Shape> get_numpy_broadcast_shape(ngraph::Shape left_shape,
+                                                            ngraph::Shape right_shape)
 {
     ngraph::Shape output_shape;
     auto rank_left = left_shape.size();
...
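For context on the helper made static above: a minimal sketch, assuming standard NumPy broadcasting rules, of how a broadcast output shape is derived. The real get_numpy_broadcast_shape also returns the full left-padded input shapes, which this sketch omits; numpy_broadcast_shape is an illustrative name, not the nGraph API.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Right-align the two shapes by left-padding the shorter one with 1s, then
// take the axis-wise maximum (assuming each axis pair is compatible, i.e.
// the sizes are equal or one of them is 1).
std::vector<std::size_t> numpy_broadcast_shape(std::vector<std::size_t> a,
                                               std::vector<std::size_t> b)
{
    while (a.size() < b.size()) a.insert(a.begin(), 1);
    while (b.size() < a.size()) b.insert(b.begin(), 1);

    std::vector<std::size_t> out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i)
    {
        out[i] = std::max(a[i], b[i]);
    }
    return out;
}

int main()
{
    for (std::size_t d : numpy_broadcast_shape({3, 1, 2}, {2, 2}))
    {
        std::cout << d << ' '; // prints: 3 2 2
    }
    std::cout << '\n';
}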
[New binary file: the serialized ONNX model for the new test, onnx/matmul_vec_ten3d.onnx. The protobuf (producer "ONNXNgraphImporter") defines a graph named "compute_graph" with a single MatMul node taking inputs A and B and producing output C; the raw bytes are omitted here.]
@@ -1346,3 +1346,19 @@ TEST(onnx, model_custom_op_default_domain)
     Outputs outputs{execute(function, inputs, "INTERPRETER")};
     EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
 }
+
+TEST(onnx, model_matmul_vec_ten3d)
+{
+    auto function = onnx_import::import_onnx_function(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_vec_ten3d.onnx"));
+
+    Inputs inputs;
+    inputs.emplace_back(std::vector<float>{0.f, 1.f});
+    inputs.emplace_back(
+        test::NDArray<float, 3>{{{0.f}, {1.f}}, {{2.f}, {3.f}}, {{4.f}, {5.f}}}.get_vector());
+
+    Outputs expected_output{test::NDArray<float, 2>{{1.f}, {3.f}, {5.f}}.get_vector()};
+
+    Outputs outputs{execute(function, inputs, "INTERPRETER")};
+    EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
+}
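A quick hand check of the new test's expectation: with the left vector v = [0, 1], the g-th {2, 1} sub-matrix of the right tensor is [[2g], [2g+1]], so the g-th dot product is 0 * 2g + 1 * (2g + 1) = 2g + 1, yielding [[1], [3], [5]] for g = 0, 1, 2, which is exactly the expected_output above.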