Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
4630c37d
Commit
4630c37d
authored
Nov 15, 2017
by
Christian Convey
Committed by
Christian Convey
Nov 29, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Adds autobroadcast builder.
parent
13330d49
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
571 additions
and
2 deletions
+571
-2
Doxyfile.in
doc/doxygen/Doxyfile.in
+0
-2
CMakeLists.txt
src/ngraph/CMakeLists.txt
+1
-0
autobroadcast.cpp
src/ngraph/builder/autobroadcast.cpp
+212
-0
autobroadcast.hpp
src/ngraph/builder/autobroadcast.hpp
+126
-0
ngraph.hpp
src/ngraph/ngraph.hpp
+1
-0
util.hpp
src/ngraph/util.hpp
+8
-0
CMakeLists.txt
test/CMakeLists.txt
+1
-0
builder_autobroadcast.cpp
test/builder_autobroadcast.cpp
+222
-0
No files found.
doc/doxygen/Doxyfile.in
View file @
4630c37d
...
...
@@ -5,6 +5,4 @@ OUTPUT_DIRECTORY = @CMAKE_CURRENT_BINARY_DIR@
INPUT = @CMAKE_SOURCE_DIR@/src
RECURSIVE = YES
EXTRACT_STATIC = YES
USE_MATHJAX = YES
src/ngraph/CMakeLists.txt
View file @
4630c37d
...
...
@@ -13,6 +13,7 @@
set
(
SRC
autodiff/adjoints.cpp
builder/autobroadcast.cpp
builder/reduce_ops.cpp
descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp
...
...
src/ngraph/builder/autobroadcast.cpp
0 → 100644
View file @
4630c37d
/*
Copyright 2017 Nervana Systems Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/common.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/util.hpp"
#include <cassert>
#include <numeric>
#include <sstream>
using
namespace
std
;
namespace
ngraph
{
namespace
builder
{
autobroadcast_incompatible_shapes
::
autobroadcast_incompatible_shapes
(
const
ngraph
::
Shape
&
shape1
,
const
ngraph
::
Shape
&
shape2
)
:
ngraph
::
ngraph_error
(
error_str
(
shape1
,
shape2
))
,
m_shape1
(
shape1
)
,
m_shape2
(
shape2
)
{
}
const
ngraph
::
Shape
&
autobroadcast_incompatible_shapes
::
get_shape1
()
const
{
return
m_shape1
;
}
const
ngraph
::
Shape
&
autobroadcast_incompatible_shapes
::
get_shape2
()
const
{
return
m_shape2
;
}
std
::
string
autobroadcast_incompatible_shapes
::
error_str
(
const
ngraph
::
Shape
&
shape1
,
const
ngraph
::
Shape
&
shape2
)
{
ostringstream
os
;
os
<<
"Auto-broadcast not possible for these input shapes:"
<<
" shape1="
<<
vector_to_string
(
shape1
)
<<
" shape2="
<<
vector_to_string
(
shape2
);
return
os
.
str
();
}
/// A utility struct representing the details computed by the
/// compute_shapes_and_broadcast_axes function.
struct
Autobroadcast_plan
{
ngraph
::
Shape
m_arg1_shape_after_possible_reshaping
;
ngraph
::
Shape
m_arg2_shape_after_possible_reshaping
;
ngraph
::
AxisSet
m_arg1_broadcast_axes
;
ngraph
::
AxisSet
m_arg2_broadcast_axes
;
ngraph
::
Shape
m_final_shape
;
};
/// @brief Compute the details regarding what reshape and/or broadcast operations must be applied to
/// arg1 and/or arg2, as well as what the final resulting shape shall be.
///
/// If this algorithm cannot handle the particular combination of shapes supplied as inputs, throw
/// an ngraph::builder::autobroadcast_incompatible_shapes exception.
///
/// @exception ngraph::builder::autobroadcast_incompatible_shapes
static
Autobroadcast_plan
compute_shapes_and_broadcast_axes
(
const
ngraph
::
Shape
&
arg1_in_shape
,
const
ngraph
::
Shape
&
arg2_in_shape
)
{
Autobroadcast_plan
plan
;
size_t
arg1_size
=
arg1_in_shape
.
size
();
size_t
arg2_size
=
arg2_in_shape
.
size
();
size_t
axis
=
std
::
max
(
arg1_size
,
arg2_size
)
-
1
;
// per numpy definition of broadcast:
// start with trailing dimensions and work forward
// two dimensions are compatible:
// * if they are equal
// * if one of them is 1
while
(
arg1_size
>=
1
||
arg2_size
>=
1
)
{
size_t
arg1_dim
=
arg1_size
?
arg1_in_shape
[
arg1_size
-
1
]
:
1
;
size_t
arg2_dim
=
arg2_size
?
arg2_in_shape
[
arg2_size
-
1
]
:
1
;
if
(
arg1_dim
==
arg2_dim
)
{
// add dimension to broadcast shape + arg1/arg2 reshape
plan
.
m_final_shape
.
insert
(
plan
.
m_final_shape
.
begin
(),
arg1_dim
);
plan
.
m_arg1_shape_after_possible_reshaping
.
insert
(
plan
.
m_arg1_shape_after_possible_reshaping
.
begin
(),
arg1_dim
);
plan
.
m_arg2_shape_after_possible_reshaping
.
insert
(
plan
.
m_arg2_shape_after_possible_reshaping
.
begin
(),
arg2_dim
);
}
else
if
(
arg2_dim
==
1
)
{
// add arg1 dimension to broadcast shape and arg1 reshape
plan
.
m_final_shape
.
insert
(
plan
.
m_final_shape
.
begin
(),
arg1_dim
);
plan
.
m_arg1_shape_after_possible_reshaping
.
insert
(
plan
.
m_arg1_shape_after_possible_reshaping
.
begin
(),
arg1_dim
);
// add current axis to arg2 broadcast axes
plan
.
m_arg2_broadcast_axes
.
insert
(
plan
.
m_arg2_broadcast_axes
.
begin
(),
axis
);
}
else
if
(
arg1_dim
==
1
)
{
// add arg2 dimension to broadcast shape and arg2 reshape
plan
.
m_final_shape
.
insert
(
plan
.
m_final_shape
.
begin
(),
arg2_dim
);
plan
.
m_arg2_shape_after_possible_reshaping
.
insert
(
plan
.
m_arg2_shape_after_possible_reshaping
.
begin
(),
arg2_dim
);
// add current axis to arg1 broadcast axes
plan
.
m_arg1_broadcast_axes
.
insert
(
plan
.
m_arg1_broadcast_axes
.
begin
(),
axis
);
}
else
{
throw
autobroadcast_incompatible_shapes
(
arg1_in_shape
,
arg2_in_shape
);
}
if
(
arg1_size
)
{
--
arg1_size
;
}
if
(
arg2_size
)
{
--
arg2_size
;
}
if
(
axis
)
{
--
axis
;
}
}
return
plan
;
}
/// If necessary, wrap \p node with an additional reshape and/or broadcast op.
/// Return a pointer to the node that produces the wrapped value.
/// If no additional reshape or broadcast op was needed, simply return \p node.
static
std
::
shared_ptr
<
Node
>
add_required_ops
(
const
std
::
shared_ptr
<
Node
>&
node
,
const
ngraph
::
Shape
&
node_shape_after_possible_reshaping
,
const
ngraph
::
AxisSet
&
node_broadcast_axes
,
const
ngraph
::
Shape
&
node_final_shape
)
{
std
::
shared_ptr
<
Node
>
return_node
{
node
};
if
(
node
->
get_shape
()
!=
node_shape_after_possible_reshaping
)
{
// tell reshape to examine input dimensions in order
ngraph
::
AxisVector
order
(
node
->
get_shape
().
size
());
std
::
iota
(
order
.
begin
(),
order
.
end
(),
0
);
return_node
=
std
::
make_shared
<
ngraph
::
op
::
Reshape
>
(
return_node
,
order
,
node_shape_after_possible_reshaping
);
}
if
(
node_final_shape
!=
node_shape_after_possible_reshaping
)
{
return_node
=
std
::
make_shared
<
ngraph
::
op
::
Broadcast
>
(
return_node
,
node_final_shape
,
node_broadcast_axes
);
}
return
return_node
;
}
std
::
pair
<
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>>
numpy_broadcast
(
const
std
::
pair
<
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>>&
args
)
{
assert
(
args
.
first
);
assert
(
args
.
second
);
const
ngraph
::
Shape
&
arg1_in_shape
=
args
.
first
->
get_shape
();
const
ngraph
::
Shape
&
arg2_in_shape
=
args
.
second
->
get_shape
();
// Handle the trivial case...
if
(
arg1_in_shape
==
arg2_in_shape
)
{
return
args
;
}
Autobroadcast_plan
plan
=
compute_shapes_and_broadcast_axes
(
arg1_in_shape
,
arg2_in_shape
);
std
::
shared_ptr
<
Node
>
arg1_out
=
add_required_ops
(
args
.
first
,
plan
.
m_arg1_shape_after_possible_reshaping
,
plan
.
m_arg1_broadcast_axes
,
plan
.
m_final_shape
);
std
::
shared_ptr
<
Node
>
arg2_out
=
add_required_ops
(
args
.
second
,
plan
.
m_arg2_shape_after_possible_reshaping
,
plan
.
m_arg2_broadcast_axes
,
plan
.
m_final_shape
);
return
{
arg1_out
,
arg2_out
};
}
}
// namespace builder
}
// namespace ngraph
src/ngraph/builder/autobroadcast.hpp
0 → 100644
View file @
4630c37d
/*
Copyright 2017 Nervana Systems Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#pragma once
#include "ngraph/except.hpp"
#include "ngraph/node.hpp"
#include <memory>
#include <utility>
namespace
ngraph
{
namespace
builder
{
class
autobroadcast_incompatible_shapes
:
public
ngraph
::
ngraph_error
{
public
:
autobroadcast_incompatible_shapes
(
const
ngraph
::
Shape
&
shape1
,
const
ngraph
::
Shape
&
shape2
);
const
ngraph
::
Shape
&
get_shape1
()
const
;
const
ngraph
::
Shape
&
get_shape2
()
const
;
private
:
const
ngraph
::
Shape
m_shape1
;
const
ngraph
::
Shape
m_shape2
;
static
std
::
string
error_str
(
const
ngraph
::
Shape
&
shape1
,
const
ngraph
::
Shape
&
shape2
);
};
/// @brief Wrap two graph nodes, if necessary, to obtain values with identical shapes,
/// using NumPy's auto-broadcast rules.
///
/// The elements in the std::pair returned by this function correspond to those supplied
/// in the std::pair provided via \p args.
///
/// If \p args.first and \p args.second produce identical shapes, then the returned std::pair
/// will have the same value as \p args.
///
/// If \p args.first and \p args.second produce different shapes, then this function creates
/// new ngraph::op::Reshape and/or ngraph::op::Broadcast nodes, as needed, to wrap
/// \p args.first and/or \p args.second in a manner that yields values with the same shape.
///
/// There are some shape combinations which the autobroadcast algoritm cannot handle.
/// An exception is thrown when such combinations are provided to this function.
///
/// @pre
/// - \p args.first is not null
/// - \p args.second is not null
///
/// @post
/// - The ngraph::Node objects pointed to by \p args.first and \p args.second have not been
/// altered by this function, except by possibly having added consumers of their values.
///
/// - If an exception was not thrown, then the return value's \p first and \p second
/// elements point to ngraph::Node objects whose output values have the same shape.
///
/// @exception ngraph::builder::autobroadcast_incompatible_shapes
std
::
pair
<
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>>
numpy_broadcast
(
const
std
::
pair
<
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>>&
args
);
/// Create a new \p NodeType node, and any additional nodes required to simulate NumPy-style autobroadcast
/// semantics. Intended for binary operations such as "Add".
///
/// @param [in] operand1_reshapeable The first operand to supply to the \p NodeType constructor. Subject to
/// being wrapped with additional nodes required for autobroadcasting. Must not be null.
///
/// @param [in] operand2_reshapeable The second operand to supply to the \p NodeType constructor. Subject to
/// being wrapped with additional nodes required for autobroadcasting. Must not be null.
///
/// @return The sink node of any/all nodes created by this function. Will never be null.
///
/// @exception ngraph::builder::autobroadcast_incompatible_shapes
template
<
typename
NodeType
>
std
::
shared_ptr
<
NodeType
>
make_with_numpy_broadcast
(
const
std
::
shared_ptr
<
Node
>&
operand1_reshapeable
,
const
std
::
shared_ptr
<
Node
>&
operand2_reshapeable
)
{
std
::
pair
<
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>>
shaped_op1_op2
=
numpy_broadcast
({
operand1_reshapeable
,
operand2_reshapeable
});
return
std
::
make_shared
<
NodeType
>
(
shaped_op1_op2
.
first
,
shaped_op1_op2
.
second
);
}
/// Create a new \p NodeType node, and any additional nodes required to simulate NumPy-style autobroadcast
/// semantics. Intended for non-binary operations such as "Select", where precisely the second and third
/// operands are subject to autobroadcast semantics.
///
/// @param [in] operand1 This operand is not subject to autobraodcast logic, and will be passed as-is as
/// the first argument to the \p NodeType constructor.
///
/// @param [in] operand2_reshapeable The second operand to supply to the \p NodeType constructor. Subject to
/// being wrapped with additional nodes required for autobroadcasting. Must not be null.
///
/// @param [in] operand3_reshapeable The third operand to supply to the \p NodeType constructor. Subject to
/// being wrapped with additional nodes required for autobroadcasting. Must not be null.
///
/// @return The sink node of any/all nodes created by this function. Will never be null.
///
/// @exception ngraph::builder::autobroadcast_incompatible_shapes
template
<
typename
NodeType
>
std
::
shared_ptr
<
NodeType
>
make_with_numpy_broadcast
(
const
std
::
shared_ptr
<
Node
>&
operand1
,
const
std
::
shared_ptr
<
Node
>&
operand2_reshapeable
,
const
std
::
shared_ptr
<
Node
>&
operand3_reshapeable
)
{
std
::
pair
<
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>>
shaped_op2_op3
=
numpy_broadcast
({
operand2_reshapeable
,
operand3_reshapeable
});
return
std
::
make_shared
<
NodeType
>
(
operand1
,
shaped_op2_op3
.
first
,
shaped_op2_op3
.
second
);
}
}
// namespace builder
}
// namespace ngraph
src/ngraph/ngraph.hpp
View file @
4630c37d
...
...
@@ -41,6 +41,7 @@
/// @brief Convenience functions that create addional graph nodes to implement commonly-used
/// recipes, for example auto-broadcast.
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/common.hpp"
#include "ngraph/descriptor/buffer.hpp"
...
...
src/ngraph/util.hpp
View file @
4630c37d
...
...
@@ -46,6 +46,14 @@ namespace ngraph
return
ss
.
str
();
}
template
<
typename
T
>
static
std
::
string
vector_to_string
(
const
std
::
vector
<
T
>&
v
)
{
std
::
ostringstream
os
;
os
<<
"[ "
<<
ngraph
::
join
(
v
)
<<
" ]"
;
return
os
.
str
();
}
template
<
typename
U
,
typename
T
>
bool
contains
(
const
U
&
container
,
const
T
&
obj
)
{
...
...
test/CMakeLists.txt
View file @
4630c37d
...
...
@@ -22,6 +22,7 @@ include_directories(
)
set
(
SRC
builder_autobroadcast.cpp
builder_reduce_ops.cpp
autodiff.cpp
build_graph.cpp
...
...
test/builder_autobroadcast.cpp
0 → 100644
View file @
4630c37d
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
ngraph
;
std
::
shared_ptr
<
ngraph
::
op
::
Parameter
>
getParamFromShape
(
const
ngraph
::
Shape
&
shape
)
{
return
std
::
make_shared
<
ngraph
::
op
::
Parameter
>
(
ngraph
::
element
::
Float32
::
element_type
(),
shape
);
}
inline
ngraph
::
Shape
getShapeFromParam
(
const
shared_ptr
<
ngraph
::
Node
>&
node
)
{
auto
type
=
std
::
dynamic_pointer_cast
<
const
ngraph
::
TensorViewType
>
(
node
->
get_value_type
());
return
type
->
get_shape
();
}
// input shapes are equal so AutoBroadcast does nothing
TEST
(
autobroadcast
,
no_broadcast_equal
)
{
ngraph
::
Shape
s2345
{
2
,
3
,
4
,
5
};
auto
lhs
=
getParamFromShape
(
s2345
);
auto
rhs
=
getParamFromShape
(
s2345
);
auto
shaped
=
ngraph
::
builder
::
numpy_broadcast
({
lhs
,
rhs
});
const
shared_ptr
<
Node
>&
ab_lhs
=
shaped
.
first
;
const
shared_ptr
<
Node
>&
ab_rhs
=
shaped
.
second
;
EXPECT_EQ
(
ab_lhs
,
lhs
);
// no change
EXPECT_EQ
(
getShapeFromParam
(
ab_lhs
),
s2345
);
EXPECT_EQ
(
ab_rhs
,
rhs
);
// no change
EXPECT_EQ
(
getShapeFromParam
(
ab_rhs
),
s2345
);
}
// input shapes are incompatable
TEST
(
autobroadcast
,
no_broadcast_incompatable
)
{
ngraph
::
Shape
s2345
{
2
,
3
,
4
,
5
};
ngraph
::
Shape
s6789
{
6
,
7
,
8
,
9
};
auto
lhs
=
getParamFromShape
(
s2345
);
auto
rhs
=
getParamFromShape
(
s6789
);
EXPECT_THROW
(
ngraph
::
builder
::
numpy_broadcast
({
lhs
,
rhs
}),
ngraph
::
builder
::
autobroadcast_incompatible_shapes
);
}
// basic broadcast test
// 1D to 2D
// lhs broadcast to 2,3
TEST
(
autobroadcast
,
normal_broadcast_2d
)
{
ngraph
::
Shape
s3
{
3
};
ngraph
::
Shape
s23
{
2
,
3
};
auto
lhs
=
getParamFromShape
(
s3
);
auto
rhs
=
getParamFromShape
(
s23
);
auto
shaped
=
ngraph
::
builder
::
numpy_broadcast
({
lhs
,
rhs
});
const
shared_ptr
<
Node
>&
ab_lhs
=
shaped
.
first
;
const
shared_ptr
<
Node
>&
ab_rhs
=
shaped
.
second
;
EXPECT_NE
(
ab_lhs
,
lhs
);
EXPECT_EQ
(
getShapeFromParam
(
ab_lhs
),
s23
);
EXPECT_EQ
(
ab_rhs
,
rhs
);
// no change
EXPECT_EQ
(
getShapeFromParam
(
ab_rhs
),
s23
);
}
// basic broadcast test
// 2D to 3D
// lhs broadcast to 2,3,4
TEST
(
autobroadcast
,
normal_broadcast_3d
)
{
ngraph
::
Shape
s34
{
3
,
4
};
ngraph
::
Shape
s234
{
2
,
3
,
4
};
auto
lhs
=
getParamFromShape
(
s34
);
auto
rhs
=
getParamFromShape
(
s234
);
auto
shaped
=
ngraph
::
builder
::
numpy_broadcast
({
lhs
,
rhs
});
const
shared_ptr
<
Node
>&
ab_lhs
=
shaped
.
first
;
const
shared_ptr
<
Node
>&
ab_rhs
=
shaped
.
second
;
EXPECT_NE
(
ab_lhs
,
lhs
);
EXPECT_EQ
(
getShapeFromParam
(
ab_lhs
),
s234
);
EXPECT_EQ
(
ab_rhs
,
rhs
);
// no change
EXPECT_EQ
(
getShapeFromParam
(
ab_rhs
),
s234
);
}
// basic broadcast test
// 3D to 4D
// lhs broadcast to 2,3,4,5
TEST
(
autobroadcast
,
normal_broadcast_4d
)
{
ngraph
::
Shape
s345
{
3
,
4
,
5
};
ngraph
::
Shape
s2345
{
2
,
3
,
4
,
5
};
auto
lhs
=
getParamFromShape
(
s345
);
auto
rhs
=
getParamFromShape
(
s2345
);
auto
shaped
=
ngraph
::
builder
::
numpy_broadcast
({
lhs
,
rhs
});
const
shared_ptr
<
Node
>&
ab_lhs
=
shaped
.
first
;
const
shared_ptr
<
Node
>&
ab_rhs
=
shaped
.
second
;
EXPECT_NE
(
ab_lhs
,
lhs
);
EXPECT_EQ
(
getShapeFromParam
(
ab_lhs
),
s2345
);
EXPECT_EQ
(
ab_rhs
,
rhs
);
// no change
EXPECT_EQ
(
getShapeFromParam
(
ab_rhs
),
s2345
);
}
// basic reshape and broadcast test
// rhs reshape to 2,3,4 then
// rhs broadcast to 2,3,4,5
TEST
(
autobroadcast
,
reshape_1x_broadcast
)
{
ngraph
::
Shape
s2345
{
2
,
3
,
4
,
5
};
ngraph
::
Shape
s2341
{
2
,
3
,
4
,
1
};
auto
lhs
=
getParamFromShape
(
s2345
);
auto
rhs
=
getParamFromShape
(
s2341
);
auto
shaped
=
ngraph
::
builder
::
numpy_broadcast
({
lhs
,
rhs
});
const
shared_ptr
<
Node
>&
ab_lhs
=
shaped
.
first
;
const
shared_ptr
<
Node
>&
ab_rhs
=
shaped
.
second
;
EXPECT_EQ
(
ab_lhs
,
lhs
);
// no change
EXPECT_EQ
(
getShapeFromParam
(
ab_lhs
),
s2345
);
EXPECT_NE
(
ab_rhs
,
rhs
);
EXPECT_EQ
(
getShapeFromParam
(
ab_rhs
),
s2345
);
}
// same as above, but additionally
// lhs reshape to 2,4,5 then
// lhs broadcast to 2,3,4,5
TEST
(
autobroadcast
,
reshape_2x_broadcast
)
{
ngraph
::
Shape
s2145
{
2
,
1
,
4
,
5
};
ngraph
::
Shape
s2341
{
2
,
3
,
4
,
1
};
auto
lhs
=
getParamFromShape
(
s2145
);
auto
rhs
=
getParamFromShape
(
s2341
);
auto
shaped
=
ngraph
::
builder
::
numpy_broadcast
({
lhs
,
rhs
});
const
shared_ptr
<
Node
>&
ab_lhs
=
shaped
.
first
;
const
shared_ptr
<
Node
>&
ab_rhs
=
shaped
.
second
;
ngraph
::
Shape
s2345
{
2
,
3
,
4
,
5
};
EXPECT_NE
(
ab_lhs
,
lhs
);
EXPECT_EQ
(
getShapeFromParam
(
ab_lhs
),
s2345
);
EXPECT_NE
(
ab_rhs
,
rhs
);
EXPECT_EQ
(
getShapeFromParam
(
ab_rhs
),
s2345
);
}
// matching singular dimension on axis 2
// should not require reshape of either lhs or rhs
// i.e. this should be the same as normal broadcast casse
// rhs broadcast to 2,3,1,5
TEST
(
autobroadcast
,
broadcast_with_dim1
)
{
ngraph
::
Shape
s2315
{
2
,
3
,
1
,
5
};
ngraph
::
Shape
s315
{
3
,
1
,
5
};
auto
lhs
=
getParamFromShape
(
s2315
);
auto
rhs
=
getParamFromShape
(
s315
);
auto
shaped
=
ngraph
::
builder
::
numpy_broadcast
({
lhs
,
rhs
});
const
shared_ptr
<
Node
>&
ab_lhs
=
shaped
.
first
;
const
shared_ptr
<
Node
>&
ab_rhs
=
shaped
.
second
;
EXPECT_EQ
(
ab_lhs
,
lhs
);
// no change
EXPECT_EQ
(
getShapeFromParam
(
ab_lhs
),
s2315
);
EXPECT_NE
(
ab_rhs
,
rhs
);
EXPECT_EQ
(
getShapeFromParam
(
ab_rhs
),
s2315
);
}
// reshape only test
// rhs reshape to 1,3,4,5 with no broadcast
TEST
(
autobroadcast
,
broadcast_with_leading_dim1
)
{
ngraph
::
Shape
s1345
{
1
,
3
,
4
,
5
};
ngraph
::
Shape
s345
{
3
,
4
,
5
};
auto
lhs
=
getParamFromShape
(
s1345
);
auto
rhs
=
getParamFromShape
(
s345
);
auto
shaped
=
ngraph
::
builder
::
numpy_broadcast
({
lhs
,
rhs
});
const
shared_ptr
<
Node
>&
ab_lhs
=
shaped
.
first
;
const
shared_ptr
<
Node
>&
ab_rhs
=
shaped
.
second
;
EXPECT_EQ
(
ab_lhs
,
lhs
);
// no change
EXPECT_EQ
(
getShapeFromParam
(
ab_lhs
),
s1345
);
EXPECT_NE
(
ab_rhs
,
rhs
);
EXPECT_EQ
(
getShapeFromParam
(
ab_rhs
),
s1345
);
}
TEST
(
autobroadcast
,
make_node_2_args
)
{
ngraph
::
Shape
s21
{
2
,
1
};
ngraph
::
Shape
s23
{
2
,
3
};
auto
lhs
=
getParamFromShape
(
s21
);
auto
rhs
=
getParamFromShape
(
s23
);
shared_ptr
<
Node
>
op
=
ngraph
::
builder
::
make_with_numpy_broadcast
<
ngraph
::
op
::
Add
>
(
lhs
,
rhs
);
EXPECT_NE
(
op
,
nullptr
);
}
TEST
(
autobroadcast
,
make_node_3_args
)
{
ngraph
::
Shape
s21
{
2
,
1
};
ngraph
::
Shape
s23
{
2
,
3
};
auto
predicates
=
std
::
make_shared
<
ngraph
::
op
::
Parameter
>
(
ngraph
::
element
::
Bool
::
element_type
(),
s23
);
auto
lhs
=
getParamFromShape
(
s21
);
auto
rhs
=
getParamFromShape
(
s23
);
shared_ptr
<
Node
>
op
=
ngraph
::
builder
::
make_with_numpy_broadcast
<
ngraph
::
op
::
Select
>
(
predicates
,
lhs
,
rhs
);
EXPECT_NE
(
op
,
nullptr
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment