Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
8ee374fc
Unverified
Commit
8ee374fc
authored
Mar 09, 2020
by
Yimei Sun
Committed by
GitHub
Mar 09, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Replace copy_with_new_args in B op set (#4426)
Co-authored-by:
Scott Cyphers
<
diyessi@users.noreply.github.com
>
parent
9a52ae47
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
30 additions
and
27 deletions
+30
-27
batch_norm.cpp
src/ngraph/op/batch_norm.cpp
+5
-3
batch_norm.hpp
src/ngraph/op/batch_norm.hpp
+6
-6
binary_convolution.cpp
src/ngraph/op/binary_convolution.cpp
+2
-1
binary_convolution.hpp
src/ngraph/op/binary_convolution.hpp
+2
-2
broadcast.cpp
src/ngraph/op/broadcast.cpp
+3
-3
broadcast.hpp
src/ngraph/op/broadcast.hpp
+6
-6
broadcast_distributed.cpp
src/ngraph/op/broadcast_distributed.cpp
+1
-1
broadcast_distributed.hpp
src/ngraph/op/broadcast_distributed.hpp
+2
-2
copy.cpp
test/copy.cpp
+3
-3
No files found.
src/ngraph/op/batch_norm.cpp
View file @
8ee374fc
...
@@ -74,7 +74,8 @@ void op::BatchNormTraining::validate_and_infer_types()
...
@@ -74,7 +74,8 @@ void op::BatchNormTraining::validate_and_infer_types()
set_output_type
(
2
,
result_et
,
result_channel_shape
);
set_output_type
(
2
,
result_et
,
result_channel_shape
);
}
}
std
::
shared_ptr
<
Node
>
op
::
BatchNormTraining
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
std
::
shared_ptr
<
Node
>
op
::
BatchNormTraining
::
clone_with_new_inputs
(
const
OutputVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
return
std
::
make_shared
<
BatchNormTraining
>
(
return
std
::
make_shared
<
BatchNormTraining
>
(
...
@@ -165,7 +166,8 @@ void op::BatchNormInference::validate_and_infer_types()
...
@@ -165,7 +166,8 @@ void op::BatchNormInference::validate_and_infer_types()
set_output_type
(
0
,
result_et
,
result_batch_shape
);
set_output_type
(
0
,
result_et
,
result_batch_shape
);
}
}
std
::
shared_ptr
<
Node
>
op
::
BatchNormInference
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
std
::
shared_ptr
<
Node
>
op
::
BatchNormInference
::
clone_with_new_inputs
(
const
OutputVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
return
std
::
make_shared
<
BatchNormInference
>
(
return
std
::
make_shared
<
BatchNormInference
>
(
...
@@ -258,7 +260,7 @@ void op::BatchNormTrainingBackprop::validate_and_infer_types()
...
@@ -258,7 +260,7 @@ void op::BatchNormTrainingBackprop::validate_and_infer_types()
}
}
std
::
shared_ptr
<
Node
>
std
::
shared_ptr
<
Node
>
op
::
BatchNormTrainingBackprop
::
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
op
::
BatchNormTrainingBackprop
::
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
return
std
::
make_shared
<
op
::
BatchNormTrainingBackprop
>
(
new_args
.
at
(
2
),
return
std
::
make_shared
<
op
::
BatchNormTrainingBackprop
>
(
new_args
.
at
(
2
),
...
...
src/ngraph/op/batch_norm.hpp
View file @
8ee374fc
...
@@ -77,8 +77,8 @@ namespace ngraph
...
@@ -77,8 +77,8 @@ namespace ngraph
double
get_eps_value
()
const
{
return
m_epsilon
;
}
double
get_eps_value
()
const
{
return
m_epsilon
;
}
void
set_eps_value
(
double
epsilon
)
{
m_epsilon
=
epsilon
;
}
void
set_eps_value
(
double
epsilon
)
{
m_epsilon
=
epsilon
;
}
virtual
std
::
shared_ptr
<
Node
>
std
::
shared_ptr
<
Node
>
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
override
;
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
override
;
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
...
@@ -144,8 +144,8 @@ namespace ngraph
...
@@ -144,8 +144,8 @@ namespace ngraph
double
get_eps_value
()
const
{
return
m_epsilon
;
}
double
get_eps_value
()
const
{
return
m_epsilon
;
}
void
set_eps_value
(
double
epsilon
)
{
m_epsilon
=
epsilon
;
}
void
set_eps_value
(
double
epsilon
)
{
m_epsilon
=
epsilon
;
}
virtual
std
::
shared_ptr
<
Node
>
std
::
shared_ptr
<
Node
>
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
override
;
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
override
;
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
/* adjoints */
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
/* adjoints */
,
...
@@ -194,8 +194,8 @@ namespace ngraph
...
@@ -194,8 +194,8 @@ namespace ngraph
double
get_eps_value
()
const
{
return
m_epsilon
;
}
double
get_eps_value
()
const
{
return
m_epsilon
;
}
void
set_eps_value
(
double
epsilon
)
{
m_epsilon
=
epsilon
;
}
void
set_eps_value
(
double
epsilon
)
{
m_epsilon
=
epsilon
;
}
virtual
std
::
shared_ptr
<
Node
>
std
::
shared_ptr
<
Node
>
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
override
;
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
override
;
private
:
private
:
static
constexpr
size_t
INPUT_GAMMA
=
0
;
static
constexpr
size_t
INPUT_GAMMA
=
0
;
...
...
src/ngraph/op/binary_convolution.cpp
View file @
8ee374fc
...
@@ -129,7 +129,8 @@ void op::v1::BinaryConvolution::validate_and_infer_types()
...
@@ -129,7 +129,8 @@ void op::v1::BinaryConvolution::validate_and_infer_types()
set_output_type
(
0
,
data_batch_et
,
result_shape
);
set_output_type
(
0
,
data_batch_et
,
result_shape
);
}
}
shared_ptr
<
Node
>
op
::
v1
::
BinaryConvolution
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
v1
::
BinaryConvolution
::
clone_with_new_inputs
(
const
OutputVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
v1
::
BinaryConvolution
>
(
new_args
.
at
(
0
),
return
make_shared
<
v1
::
BinaryConvolution
>
(
new_args
.
at
(
0
),
...
...
src/ngraph/op/binary_convolution.hpp
View file @
8ee374fc
...
@@ -76,8 +76,8 @@ namespace ngraph
...
@@ -76,8 +76,8 @@ namespace ngraph
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
std
::
shared_ptr
<
Node
>
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
override
;
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
override
;
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
OutputVector
&
deltas
)
override
;
const
OutputVector
&
deltas
)
override
;
...
...
src/ngraph/op/broadcast.cpp
View file @
8ee374fc
...
@@ -246,7 +246,7 @@ void op::v1::Broadcast::validate_and_infer_types()
...
@@ -246,7 +246,7 @@ void op::v1::Broadcast::validate_and_infer_types()
set_output_type
(
0
,
get_input_element_type
(
0
),
result_shape
);
set_output_type
(
0
,
get_input_element_type
(
0
),
result_shape
);
}
}
shared_ptr
<
Node
>
op
::
v1
::
Broadcast
::
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
v1
::
Broadcast
::
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
v1
::
Broadcast
>
(
return
make_shared
<
v1
::
Broadcast
>
(
...
@@ -339,7 +339,7 @@ void op::v0::Broadcast::validate_and_infer_types()
...
@@ -339,7 +339,7 @@ void op::v0::Broadcast::validate_and_infer_types()
set_output_type
(
0
,
get_input_element_type
(
0
),
m_shape
);
set_output_type
(
0
,
get_input_element_type
(
0
),
m_shape
);
}
}
shared_ptr
<
Node
>
op
::
v0
::
Broadcast
::
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
v0
::
Broadcast
::
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
v0
::
Broadcast
>
(
new_args
.
at
(
0
),
m_shape
,
m_broadcast_axes
);
return
make_shared
<
v0
::
Broadcast
>
(
new_args
.
at
(
0
),
m_shape
,
m_broadcast_axes
);
...
@@ -373,7 +373,7 @@ bool op::v0::BroadcastLike::visit_attributes(AttributeVisitor& visitor)
...
@@ -373,7 +373,7 @@ bool op::v0::BroadcastLike::visit_attributes(AttributeVisitor& visitor)
return
true
;
return
true
;
}
}
shared_ptr
<
Node
>
op
::
v0
::
BroadcastLike
::
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
v0
::
BroadcastLike
::
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
if
(
new_args
.
size
()
!=
2
)
{
{
...
...
src/ngraph/op/broadcast.hpp
View file @
8ee374fc
...
@@ -48,8 +48,8 @@ namespace ngraph
...
@@ -48,8 +48,8 @@ namespace ngraph
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
std
::
shared_ptr
<
Node
>
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
override
;
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
override
;
/// \return A set containing the indices of the broadcast axes (0-based).
/// \return A set containing the indices of the broadcast axes (0-based).
const
AxisSet
&
get_broadcast_axes
()
const
{
return
m_broadcast_axes
;
}
const
AxisSet
&
get_broadcast_axes
()
const
{
return
m_broadcast_axes
;
}
...
@@ -93,8 +93,8 @@ namespace ngraph
...
@@ -93,8 +93,8 @@ namespace ngraph
const
Output
<
Node
>&
like_arg
,
const
Output
<
Node
>&
like_arg
,
const
AxisSet
&
initial_broadcast_axes
);
const
AxisSet
&
initial_broadcast_axes
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
std
::
shared_ptr
<
Node
>
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
override
;
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
override
;
void
infer_shape
()
override
;
void
infer_shape
()
override
;
const
AxisSet
&
get_initial_broadcast_axes
()
const
const
AxisSet
&
get_initial_broadcast_axes
()
const
...
@@ -155,8 +155,8 @@ namespace ngraph
...
@@ -155,8 +155,8 @@ namespace ngraph
size_t
get_version
()
const
override
{
return
1
;
}
size_t
get_version
()
const
override
{
return
1
;
}
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
std
::
shared_ptr
<
Node
>
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
override
;
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
override
;
/// \return Broadcast Specification.
/// \return Broadcast Specification.
const
AutoBroadcastSpec
&
get_broadcast_spec
()
const
{
return
m_broadcast_spec
;
}
const
AutoBroadcastSpec
&
get_broadcast_spec
()
const
{
return
m_broadcast_spec
;
}
...
...
src/ngraph/op/broadcast_distributed.cpp
View file @
8ee374fc
...
@@ -48,7 +48,7 @@ void op::BroadcastDistributed::validate_and_infer_types()
...
@@ -48,7 +48,7 @@ void op::BroadcastDistributed::validate_and_infer_types()
set_output_type
(
0
,
get_input_element_type
(
0
),
get_input_partial_shape
(
0
));
set_output_type
(
0
,
get_input_element_type
(
0
),
get_input_partial_shape
(
0
));
}
}
shared_ptr
<
Node
>
op
::
BroadcastDistributed
::
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
BroadcastDistributed
::
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
BroadcastDistributed
>
(
new_args
.
at
(
0
),
m_root_id
);
return
make_shared
<
BroadcastDistributed
>
(
new_args
.
at
(
0
),
m_root_id
);
...
...
src/ngraph/op/broadcast_distributed.hpp
View file @
8ee374fc
...
@@ -36,8 +36,8 @@ namespace ngraph
...
@@ -36,8 +36,8 @@ namespace ngraph
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
std
::
shared_ptr
<
Node
>
c
opy_with_new_args
(
const
Node
Vector
&
new_args
)
const
override
;
c
lone_with_new_inputs
(
const
Output
Vector
&
new_args
)
const
override
;
int64_t
get_root_id
()
const
;
int64_t
get_root_id
()
const
;
void
set_root_id
(
int64_t
root_id
);
void
set_root_id
(
int64_t
root_id
);
...
...
test/copy.cpp
View file @
8ee374fc
...
@@ -88,18 +88,18 @@ TEST(copy, broadcast)
...
@@ -88,18 +88,18 @@ TEST(copy, broadcast)
{
{
Shape
shape1
{
1
};
Shape
shape1
{
1
};
auto
arg0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape1
);
auto
arg0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape1
);
Node
Vector
new_args
{
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape1
)};
Output
Vector
new_args
{
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape1
)};
Shape
shape
{
4
,
1
,
3
};
Shape
shape
{
4
,
1
,
3
};
AxisSet
axes
{
0
,
2
};
AxisSet
axes
{
0
,
2
};
auto
node
=
make_shared
<
op
::
Broadcast
>
(
arg0
,
shape
,
axes
);
auto
node
=
make_shared
<
op
::
Broadcast
>
(
arg0
,
shape
,
axes
);
auto
new_node
=
node
->
copy_with_new_
arg
s
(
new_args
);
auto
new_node
=
node
->
copy_with_new_
input
s
(
new_args
);
auto
node_cast
=
as_type_ptr
<
op
::
Broadcast
>
(
new_node
);
auto
node_cast
=
as_type_ptr
<
op
::
Broadcast
>
(
new_node
);
ASSERT_NE
(
node_cast
,
nullptr
);
ASSERT_NE
(
node_cast
,
nullptr
);
ASSERT_TRUE
(
nullptr
!=
new_node
);
ASSERT_TRUE
(
nullptr
!=
new_node
);
ASSERT_TRUE
(
new_args
==
new_node
->
get_argument
s
());
ASSERT_TRUE
(
new_args
==
new_node
->
input_value
s
());
ASSERT_TRUE
(
shape
==
node_cast
->
get_broadcast_shape
());
ASSERT_TRUE
(
shape
==
node_cast
->
get_broadcast_shape
());
ASSERT_TRUE
(
axes
==
node_cast
->
get_broadcast_axes
());
ASSERT_TRUE
(
axes
==
node_cast
->
get_broadcast_axes
());
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment