Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
7b1dc3e3
Unverified
Commit
7b1dc3e3
authored
Jan 10, 2018
by
Robert Kimball
Committed by
GitHub
Jan 10, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix some is_functionally_identical methods (#365)
parent
7df687c1
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
80 additions
and
0 deletions
+80
-0
broadcast.cpp
src/ngraph/ops/broadcast.cpp
+16
-0
broadcast.hpp
src/ngraph/ops/broadcast.hpp
+2
-0
convert.cpp
src/ngraph/ops/convert.cpp
+15
-0
convert.hpp
src/ngraph/ops/convert.hpp
+2
-0
get_output_element.cpp
src/ngraph/ops/get_output_element.cpp
+5
-0
get_output_element.hpp
src/ngraph/ops/get_output_element.hpp
+2
-0
reduce_window.cpp
src/ngraph/ops/reduce_window.cpp
+5
-0
reduce_window.hpp
src/ngraph/ops/reduce_window.hpp
+2
-0
reverse.cpp
src/ngraph/ops/reverse.cpp
+15
-0
reverse.hpp
src/ngraph/ops/reverse.hpp
+2
-0
select_and_scatter.cpp
src/ngraph/ops/select_and_scatter.cpp
+5
-0
select_and_scatter.hpp
src/ngraph/ops/select_and_scatter.hpp
+2
-0
xla_get_tuple_element.cpp
src/ngraph/ops/xla_get_tuple_element.cpp
+5
-0
xla_get_tuple_element.hpp
src/ngraph/ops/xla_get_tuple_element.hpp
+2
-0
No files found.
src/ngraph/ops/broadcast.cpp
View file @
7b1dc3e3
...
@@ -45,3 +45,19 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
...
@@ -45,3 +45,19 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints
.
add_delta
(
x
,
make_shared
<
op
::
Sum
>
(
delta
,
m_broadcast_axes
));
adjoints
.
add_delta
(
x
,
make_shared
<
op
::
Sum
>
(
delta
,
m_broadcast_axes
));
}
}
bool
op
::
Broadcast
::
is_functionally_identical
(
const
Node
&
other
)
const
{
bool
rc
=
true
;
if
(
Node
::
is_functionally_identical
(
other
))
{
const
Broadcast
&
obj
=
dynamic_cast
<
const
Broadcast
&>
(
other
);
rc
&=
m_shape
==
obj
.
m_shape
;
rc
&=
m_broadcast_axes
==
obj
.
m_broadcast_axes
;
}
else
{
rc
=
false
;
}
return
rc
;
}
src/ngraph/ops/broadcast.hpp
View file @
7b1dc3e3
...
@@ -73,6 +73,8 @@ namespace ngraph
...
@@ -73,6 +73,8 @@ namespace ngraph
/// \return An set containing the indices of the broadcast axes (0-based).
/// \return An set containing the indices of the broadcast axes (0-based).
const
AxisSet
&
get_broadcast_axes
()
const
{
return
m_broadcast_axes
;
}
const
AxisSet
&
get_broadcast_axes
()
const
{
return
m_broadcast_axes
;
}
const
Shape
&
get_broadcast_shape
()
const
{
return
m_shape
;
}
const
Shape
&
get_broadcast_shape
()
const
{
return
m_shape
;
}
bool
is_functionally_identical
(
const
Node
&
)
const
override
;
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
std
::
shared_ptr
<
Node
>&
delta
)
override
;
const
std
::
shared_ptr
<
Node
>&
delta
)
override
;
...
...
src/ngraph/ops/convert.cpp
View file @
7b1dc3e3
...
@@ -36,3 +36,18 @@ void ngraph::op::Convert::generate_adjoints(autodiff::Adjoints& adjoints,
...
@@ -36,3 +36,18 @@ void ngraph::op::Convert::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints
.
add_delta
(
x
,
std
::
make_shared
<
op
::
Convert
>
(
delta
,
x
->
get_element_type
()));
adjoints
.
add_delta
(
x
,
std
::
make_shared
<
op
::
Convert
>
(
delta
,
x
->
get_element_type
()));
}
}
bool
op
::
Convert
::
is_functionally_identical
(
const
Node
&
other
)
const
{
bool
rc
=
true
;
if
(
Node
::
is_functionally_identical
(
other
))
{
const
Convert
&
obj
=
dynamic_cast
<
const
Convert
&>
(
other
);
rc
&=
m_element_type
==
obj
.
m_element_type
;
}
else
{
rc
=
false
;
}
return
rc
;
}
src/ngraph/ops/convert.hpp
View file @
7b1dc3e3
...
@@ -60,6 +60,8 @@ namespace ngraph
...
@@ -60,6 +60,8 @@ namespace ngraph
}
}
const
element
::
Type
&
get_convert_element_type
()
const
{
return
m_element_type
;
}
const
element
::
Type
&
get_convert_element_type
()
const
{
return
m_element_type
;
}
bool
is_functionally_identical
(
const
Node
&
)
const
override
;
protected
:
protected
:
const
ngraph
::
element
::
Type
&
m_element_type
;
const
ngraph
::
element
::
Type
&
m_element_type
;
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
...
...
src/ngraph/ops/get_output_element.cpp
View file @
7b1dc3e3
...
@@ -30,3 +30,8 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t
...
@@ -30,3 +30,8 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t
set_value_type_checked
(
arg
->
get_output_element_type
(
n
),
arg
->
get_output_shape
(
n
));
set_value_type_checked
(
arg
->
get_output_element_type
(
n
),
arg
->
get_output_shape
(
n
));
}
}
bool
op
::
GetOutputElement
::
is_functionally_identical
(
const
Node
&
other
)
const
{
return
false
;
}
src/ngraph/ops/get_output_element.hpp
View file @
7b1dc3e3
...
@@ -58,6 +58,8 @@ namespace ngraph
...
@@ -58,6 +58,8 @@ namespace ngraph
/// \return The index of the tuple element to get.
/// \return The index of the tuple element to get.
size_t
get_n
()
const
{
return
m_n
;
}
size_t
get_n
()
const
{
return
m_n
;
}
bool
is_functionally_identical
(
const
Node
&
)
const
override
;
protected
:
protected
:
size_t
m_n
;
size_t
m_n
;
};
};
...
...
src/ngraph/ops/reduce_window.cpp
View file @
7b1dc3e3
...
@@ -127,3 +127,8 @@ op::ReduceWindow::ReduceWindow(const std::shared_ptr<Node>& arg_reductee,
...
@@ -127,3 +127,8 @@ op::ReduceWindow::ReduceWindow(const std::shared_ptr<Node>& arg_reductee,
set_value_type_checked
(
input_reductee
.
get_element_type
(),
result_shape
);
set_value_type_checked
(
input_reductee
.
get_element_type
(),
result_shape
);
}
}
bool
op
::
ReduceWindow
::
is_functionally_identical
(
const
Node
&
other
)
const
{
return
false
;
}
src/ngraph/ops/reduce_window.hpp
View file @
7b1dc3e3
...
@@ -85,6 +85,8 @@ namespace ngraph
...
@@ -85,6 +85,8 @@ namespace ngraph
const
Shape
&
get_window_shape
()
const
{
return
m_window_shape
;
}
const
Shape
&
get_window_shape
()
const
{
return
m_window_shape
;
}
/// \return The window movement strides.
/// \return The window movement strides.
const
Strides
&
get_window_movement_strides
()
const
{
return
m_window_movement_strides
;
}
const
Strides
&
get_window_movement_strides
()
const
{
return
m_window_movement_strides
;
}
bool
is_functionally_identical
(
const
Node
&
)
const
override
;
protected
:
protected
:
std
::
shared_ptr
<
Function
>
m_reduction_function
;
std
::
shared_ptr
<
Function
>
m_reduction_function
;
Shape
m_window_shape
;
Shape
m_window_shape
;
...
...
src/ngraph/ops/reverse.cpp
View file @
7b1dc3e3
...
@@ -50,3 +50,18 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints,
...
@@ -50,3 +50,18 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints
.
add_delta
(
x
,
make_shared
<
op
::
Reverse
>
(
delta
,
m_reversed_axes
));
adjoints
.
add_delta
(
x
,
make_shared
<
op
::
Reverse
>
(
delta
,
m_reversed_axes
));
}
}
bool
op
::
Reverse
::
is_functionally_identical
(
const
Node
&
other
)
const
{
bool
rc
=
true
;
if
(
Node
::
is_functionally_identical
(
other
))
{
const
Reverse
&
obj
=
dynamic_cast
<
const
Reverse
&>
(
other
);
rc
&=
m_reversed_axes
==
obj
.
m_reversed_axes
;
}
else
{
rc
=
false
;
}
return
rc
;
}
src/ngraph/ops/reverse.hpp
View file @
7b1dc3e3
...
@@ -60,6 +60,8 @@ namespace ngraph
...
@@ -60,6 +60,8 @@ namespace ngraph
/// \return The set of axes to reverse.
/// \return The set of axes to reverse.
const
AxisSet
&
get_reversed_axes
()
const
{
return
m_reversed_axes
;
}
const
AxisSet
&
get_reversed_axes
()
const
{
return
m_reversed_axes
;
}
bool
is_functionally_identical
(
const
Node
&
)
const
override
;
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
std
::
shared_ptr
<
Node
>&
delta
)
override
;
const
std
::
shared_ptr
<
Node
>&
delta
)
override
;
...
...
src/ngraph/ops/select_and_scatter.cpp
View file @
7b1dc3e3
...
@@ -213,3 +213,8 @@ op::SelectAndScatter::SelectAndScatter(const std::shared_ptr<Node>& arg_selectee
...
@@ -213,3 +213,8 @@ op::SelectAndScatter::SelectAndScatter(const std::shared_ptr<Node>& arg_selectee
//
//
set_value_type_checked
(
input_selectee_element_type
,
input_selectee_shape
);
set_value_type_checked
(
input_selectee_element_type
,
input_selectee_shape
);
}
}
bool
op
::
SelectAndScatter
::
is_functionally_identical
(
const
Node
&
other
)
const
{
return
false
;
}
src/ngraph/ops/select_and_scatter.hpp
View file @
7b1dc3e3
...
@@ -88,6 +88,8 @@ namespace ngraph
...
@@ -88,6 +88,8 @@ namespace ngraph
const
Shape
&
get_window_shape
()
const
{
return
m_window_shape
;
}
const
Shape
&
get_window_shape
()
const
{
return
m_window_shape
;
}
/// \return The window movement strides.
/// \return The window movement strides.
const
Strides
&
get_window_movement_strides
()
const
{
return
m_window_movement_strides
;
}
const
Strides
&
get_window_movement_strides
()
const
{
return
m_window_movement_strides
;
}
bool
is_functionally_identical
(
const
Node
&
)
const
override
;
protected
:
protected
:
std
::
shared_ptr
<
Function
>
m_selection_function
;
std
::
shared_ptr
<
Function
>
m_selection_function
;
std
::
shared_ptr
<
Function
>
m_scatter_function
;
std
::
shared_ptr
<
Function
>
m_scatter_function
;
...
...
src/ngraph/ops/xla_get_tuple_element.cpp
View file @
7b1dc3e3
...
@@ -53,3 +53,8 @@ const Nodes& op::XLAGetTupleElement::get_tuple_elements() const
...
@@ -53,3 +53,8 @@ const Nodes& op::XLAGetTupleElement::get_tuple_elements() const
{
{
return
get_tuple_value
()
->
get_tuple_elements
();
return
get_tuple_value
()
->
get_tuple_elements
();
}
}
bool
op
::
XLAGetTupleElement
::
is_functionally_identical
(
const
Node
&
other
)
const
{
return
false
;
}
src/ngraph/ops/xla_get_tuple_element.hpp
View file @
7b1dc3e3
...
@@ -64,6 +64,8 @@ namespace ngraph
...
@@ -64,6 +64,8 @@ namespace ngraph
/// \return The index of the tuple element to get.
/// \return The index of the tuple element to get.
size_t
get_n
()
const
{
return
m_n
;
}
size_t
get_n
()
const
{
return
m_n
;
}
bool
is_functionally_identical
(
const
Node
&
)
const
override
;
protected
:
protected
:
std
::
shared_ptr
<
XLANode
>
m_arg
;
std
::
shared_ptr
<
XLANode
>
m_arg
;
size_t
m_n
;
size_t
m_n
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment