Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
17af4266
Commit
17af4266
authored
Sep 13, 2018
by
Scott Cyphers
Committed by
Robert Kimball
Sep 13, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix some validation errors (#1603)
parent
fe676f72
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
56 additions
and
9 deletions
+56
-9
node.cpp
src/ngraph/node.cpp
+15
-4
node.hpp
src/ngraph/node.hpp
+21
-0
broadcast.cpp
src/ngraph/op/broadcast.cpp
+15
-3
broadcast.hpp
src/ngraph/op/broadcast.hpp
+3
-0
constant.hpp
src/ngraph/op/constant.hpp
+0
-1
get_output_element.cpp
src/ngraph/op/get_output_element.cpp
+1
-0
util.cpp
src/ngraph/util.cpp
+1
-1
No files found.
src/ngraph/node.cpp
View file @
17af4266
...
@@ -201,16 +201,27 @@ namespace ngraph
...
@@ -201,16 +201,27 @@ namespace ngraph
{
{
ostream
&
operator
<<
(
ostream
&
out
,
const
Node
&
node
)
ostream
&
operator
<<
(
ostream
&
out
,
const
Node
&
node
)
{
{
out
<<
node
.
description
()
<<
'['
<<
node
.
get_name
()
<<
"]("
;
return
out
<<
NodeDescription
(
node
,
false
);
}
}
std
::
ostream
&
Node
::
write_short_description
(
std
::
ostream
&
out
)
const
{
return
out
<<
get_name
();
}
std
::
ostream
&
Node
::
write_long_description
(
std
::
ostream
&
out
)
const
{
out
<<
description
()
<<
'['
<<
get_name
()
<<
"]("
;
string
sep
=
""
;
string
sep
=
""
;
for
(
auto
arg
:
node
.
get_arguments
())
for
(
auto
arg
:
get_arguments
())
{
{
out
<<
sep
<<
arg
->
get_name
(
);
out
<<
sep
<<
NodeDescription
(
*
arg
,
true
);
sep
=
", "
;
sep
=
", "
;
}
}
out
<<
")"
;
out
<<
")"
;
return
out
;
return
out
;
}
}
}
size_t
Node
::
get_output_size
()
const
size_t
Node
::
get_output_size
()
const
...
...
src/ngraph/node.hpp
View file @
17af4266
...
@@ -132,6 +132,8 @@ namespace ngraph
...
@@ -132,6 +132,8 @@ namespace ngraph
virtual
bool
is_commutative
()
{
return
false
;
}
virtual
bool
is_commutative
()
{
return
false
;
}
size_t
get_instance_id
()
const
{
return
m_instance_id
;
}
size_t
get_instance_id
()
const
{
return
m_instance_id
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
virtual
std
::
ostream
&
write_short_description
(
std
::
ostream
&
)
const
;
virtual
std
::
ostream
&
write_long_description
(
std
::
ostream
&
)
const
;
// TODO: Deprecate
// TODO: Deprecate
std
::
deque
<
descriptor
::
Input
>&
get_inputs
()
{
return
m_inputs
;
}
std
::
deque
<
descriptor
::
Input
>&
get_inputs
()
{
return
m_inputs
;
}
...
@@ -253,6 +255,25 @@ namespace ngraph
...
@@ -253,6 +255,25 @@ namespace ngraph
}
}
};
};
class
NodeDescription
{
public
:
NodeDescription
(
const
Node
&
node
,
bool
is_short
)
:
m_node
(
node
)
,
m_is_short
(
is_short
)
{
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
NodeDescription
node_description
)
{
return
node_description
.
m_is_short
?
node_description
.
m_node
.
write_short_description
(
out
)
:
node_description
.
m_node
.
write_long_description
(
out
);
}
const
Node
&
m_node
;
bool
m_is_short
;
};
void
check_new_args_count
(
const
Node
*
node
,
const
NodeVector
&
new_args
);
void
check_new_args_count
(
const
Node
*
node
,
const
NodeVector
&
new_args
);
}
}
...
...
src/ngraph/op/broadcast.cpp
View file @
17af4266
...
@@ -85,7 +85,8 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
...
@@ -85,7 +85,8 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
op
::
BroadcastLike
::
BroadcastLike
(
const
std
::
shared_ptr
<
Node
>&
arg
,
op
::
BroadcastLike
::
BroadcastLike
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
std
::
shared_ptr
<
Node
>&
like_arg
,
const
std
::
shared_ptr
<
Node
>&
like_arg
,
const
AxisSet
&
broadcast_axes
)
const
AxisSet
&
broadcast_axes
)
:
Broadcast
(
"BroadcastLike"
,
{
arg
,
like_arg
},
{},
broadcast_axes
)
:
Broadcast
(
"BroadcastLike"
,
{
arg
,
like_arg
},
{},
{})
,
m_initial_broadcast_axes
(
broadcast_axes
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
@@ -96,18 +97,29 @@ shared_ptr<Node> op::BroadcastLike::copy_with_new_args(const NodeVector& new_arg
...
@@ -96,18 +97,29 @@ shared_ptr<Node> op::BroadcastLike::copy_with_new_args(const NodeVector& new_arg
{
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
}
return
make_shared
<
BroadcastLike
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_broadcast_axes
);
return
make_shared
<
BroadcastLike
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_
initial_
broadcast_axes
);
}
}
void
op
::
BroadcastLike
::
infer_shape
()
void
op
::
BroadcastLike
::
infer_shape
()
{
{
const
Shape
&
in_shape
=
get_input_shape
(
0
);
const
Shape
&
in_shape
=
get_input_shape
(
0
);
m_shape
=
get_input_shape
(
1
);
m_shape
=
get_input_shape
(
1
);
m_broadcast_axes
=
m_initial_broadcast_axes
;
if
(
m_broadcast_axes
.
size
()
==
0
)
if
(
m_broadcast_axes
.
size
()
==
0
)
{
{
for
(
size_t
i
=
in_shape
.
size
();
i
<
m_shape
.
size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
m_shape
.
size
();
++
i
)
{
if
(
i
<
in_shape
.
size
())
{
if
(
in_shape
.
at
(
i
)
==
1
&&
m_shape
.
at
(
i
)
>
1
)
{
{
m_broadcast_axes
.
insert
(
i
);
m_broadcast_axes
.
insert
(
i
);
}
}
}
}
else
{
m_broadcast_axes
.
insert
(
i
);
}
}
}
}
}
src/ngraph/op/broadcast.hpp
View file @
17af4266
...
@@ -80,6 +80,9 @@ namespace ngraph
...
@@ -80,6 +80,9 @@ namespace ngraph
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
void
infer_shape
()
override
;
void
infer_shape
()
override
;
protected
:
AxisSet
m_initial_broadcast_axes
;
};
};
}
}
}
}
src/ngraph/op/constant.hpp
View file @
17af4266
...
@@ -108,7 +108,6 @@ namespace ngraph
...
@@ -108,7 +108,6 @@ namespace ngraph
void
validate_and_infer_types
()
override
void
validate_and_infer_types
()
override
{
{
Node
::
validate_and_infer_types
();
infer_element_type
();
infer_element_type
();
set_output_type
(
0
,
m_element_type
,
m_shape
);
set_output_type
(
0
,
m_element_type
,
m_shape
);
}
}
...
...
src/ngraph/op/get_output_element.cpp
View file @
17af4266
...
@@ -25,6 +25,7 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
...
@@ -25,6 +25,7 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
:
Node
(
"GetOutputElement"
,
{
arg
})
:
Node
(
"GetOutputElement"
,
{
arg
})
,
m_n
{
n
}
,
m_n
{
n
}
{
{
constructor_validate_and_infer_types
();
NODE_VALIDATION_ASSERT
(
this
,
m_n
<
arg
->
get_output_size
())
NODE_VALIDATION_ASSERT
(
this
,
m_n
<
arg
->
get_output_size
())
<<
"Output at index "
<<
m_n
<<
" requested, but argument has only "
<<
"Output at index "
<<
m_n
<<
" requested, but argument has only "
<<
arg
->
get_output_size
()
<<
" outputs."
;
<<
arg
->
get_output_size
()
<<
" outputs."
;
...
...
src/ngraph/util.cpp
View file @
17af4266
...
@@ -207,7 +207,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
...
@@ -207,7 +207,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
std
::
unordered_set
<
std
::
shared_ptr
<
Node
>>
in_bprop
;
std
::
unordered_set
<
std
::
shared_ptr
<
Node
>>
in_bprop
;
ngraph
::
traverse_nodes
(
bprop
,
ngraph
::
traverse_nodes
(
bprop
,
[
&
in_bprop
](
std
::
shared_ptr
<
Node
>
node
)
{
[
&
in_bprop
](
std
::
shared_ptr
<
Node
>
node
)
{
if
(
node
->
get_output
s
().
size
()
==
1
)
if
(
node
->
get_output
_
size
()
==
1
)
{
{
if
(
in_bprop
.
count
(
node
)
==
0
)
if
(
in_bprop
.
count
(
node
)
==
0
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment