Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
e1ad1900
Commit
e1ad1900
authored
Dec 06, 2019
by
Scott Cyphers
Committed by
Sang Ik Lee
Dec 06, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add additional exports (#4006)
* Add exports * Work-around windows issues * windows * Avoid vectors
parent
b231ccc3
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
213 additions
and
27 deletions
+213
-27
attribute_visitor.hpp
src/ngraph/attribute_visitor.hpp
+1
-1
node.hpp
src/ngraph/node.hpp
+212
-26
No files found.
src/ngraph/attribute_visitor.hpp
View file @
e1ad1900
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
/// Attributes are the values set when building a graph which are not
/// Attributes are the values set when building a graph which are not
/// computed as the graph executes. Values computed from the graph topology and attributes
/// computed as the graph executes. Values computed from the graph topology and attributes
/// during compilation are not attributes.
/// during compilation are not attributes.
class
AttributeVisitor
class
NGRAPH_API
AttributeVisitor
{
{
public
:
public
:
virtual
~
AttributeVisitor
()
{}
virtual
~
AttributeVisitor
()
{}
...
...
src/ngraph/node.hpp
View file @
e1ad1900
...
@@ -531,22 +531,32 @@ namespace ngraph
...
@@ -531,22 +531,32 @@ namespace ngraph
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
Variant
>>
m_rt_info
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
Variant
>>
m_rt_info
;
};
};
/// \brief A handle for one of a node's inputs.
template
<
typename
NodeType
>
template
<
typename
NodeType
>
class
Input
class
Input
{
{
};
template
<
typename
NodeType
>
class
Output
{
};
/// \brief A handle for one of a node's inputs.
template
<>
class
Input
<
Node
>
{
public
:
public
:
/// \brief Constructs a Input.
/// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle.
/// \param node Pointer to the node for the input handle.
/// \param index The index of the input.
/// \param index The index of the input.
Input
(
Node
Type
*
node
,
size_t
index
)
Input
(
Node
*
node
,
size_t
index
)
:
m_node
(
node
)
:
m_node
(
node
)
,
m_index
(
index
)
,
m_index
(
index
)
{
{
}
}
/// \return A pointer to the node referenced by this input handle.
/// \return A pointer to the node referenced by this input handle.
Node
Type
*
get_node
()
const
{
return
m_node
;
}
Node
*
get_node
()
const
{
return
m_node
;
}
/// \return The index of the input referred to by this input handle.
/// \return The index of the input referred to by this input handle.
size_t
get_index
()
const
{
return
m_index
;
}
size_t
get_index
()
const
{
return
m_index
;
}
/// \return The element type of the input referred to by this input handle.
/// \return The element type of the input referred to by this input handle.
...
@@ -604,19 +614,92 @@ namespace ngraph
...
@@ -604,19 +614,92 @@ namespace ngraph
bool
operator
<=
(
const
Input
&
other
)
const
{
return
!
(
*
this
>
other
);
}
bool
operator
<=
(
const
Input
&
other
)
const
{
return
!
(
*
this
>
other
);
}
bool
operator
>=
(
const
Input
&
other
)
const
{
return
!
(
*
this
<
other
);
}
bool
operator
>=
(
const
Input
&
other
)
const
{
return
!
(
*
this
<
other
);
}
private
:
private
:
NodeType
*
const
m_node
;
Node
*
const
m_node
;
const
size_t
m_index
;
};
/// \brief A handle for one of a node's inputs.
template
<>
class
NGRAPH_API
Input
<
const
Node
>
{
public
:
/// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle.
/// \param index The index of the input.
Input
(
const
Node
*
node
,
size_t
index
)
:
m_node
(
node
)
,
m_index
(
index
)
{
}
/// \return A pointer to the node referenced by this input handle.
const
Node
*
get_node
()
const
{
return
m_node
;
}
/// \return The index of the input referred to by this input handle.
size_t
get_index
()
const
{
return
m_index
;
}
/// \return The element type of the input referred to by this input handle.
const
element
::
Type
&
get_element_type
()
const
{
return
m_node
->
get_input_element_type
(
m_index
);
}
/// \return The shape of the input referred to by this input handle.
const
Shape
&
get_shape
()
const
{
return
m_node
->
get_input_shape
(
m_index
);
}
/// \return The partial shape of the input referred to by this input handle.
const
PartialShape
&
get_partial_shape
()
const
{
return
m_node
->
get_input_partial_shape
(
m_index
);
}
/// \return A handle to the output that is connected to this input.
Output
<
Node
>
get_source_output
()
const
;
/// \return A reference to the tensor descriptor for this input.
descriptor
::
Tensor
&
get_tensor
()
const
{
return
m_node
->
m_inputs
.
at
(
m_index
).
get_output
().
get_tensor
();
}
/// \return A shared pointer to the tensor descriptor for this input.
std
::
shared_ptr
<
descriptor
::
Tensor
>
get_tensor_ptr
()
const
{
return
m_node
->
m_inputs
.
at
(
m_index
).
get_output
().
get_tensor_ptr
();
}
/// \return true if this input is relevant to its node's output shapes; else false.
bool
get_is_relevant_to_shapes
()
const
{
return
m_node
->
m_inputs
.
at
(
m_index
).
get_is_relevant_to_shape
();
}
/// \return true if this input is relevant to its node's output values; else false.
bool
get_is_relevant_to_values
()
const
{
return
m_node
->
m_inputs
.
at
(
m_index
).
get_is_relevant_to_value
();
}
bool
operator
==
(
const
Input
&
other
)
const
{
return
m_node
==
other
.
m_node
&&
m_index
==
other
.
m_index
;
}
bool
operator
!=
(
const
Input
&
other
)
const
{
return
!
(
*
this
==
other
);
}
bool
operator
<
(
const
Input
&
other
)
const
{
return
m_node
<
other
.
m_node
||
(
m_node
==
other
.
m_node
&&
m_index
<
other
.
m_index
);
}
bool
operator
>
(
const
Input
&
other
)
const
{
return
m_node
>
other
.
m_node
||
(
m_node
==
other
.
m_node
&&
m_index
>
other
.
m_index
);
}
bool
operator
<=
(
const
Input
&
other
)
const
{
return
!
(
*
this
>
other
);
}
bool
operator
>=
(
const
Input
&
other
)
const
{
return
!
(
*
this
<
other
);
}
private
:
const
Node
*
const
m_node
;
const
size_t
m_index
;
const
size_t
m_index
;
};
};
/// \brief A handle for one of a node's outputs.
/// \brief A handle for one of a node's outputs.
template
<
typename
NodeType
=
Node
>
template
<>
class
Output
class
NGRAPH_API
Output
<
Node
>
{
{
public
:
public
:
/// \brief Constructs a Output.
/// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle.
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
/// \param index The index of the output.
Output
(
Node
Type
*
node
,
size_t
index
)
Output
(
Node
*
node
,
size_t
index
)
:
m_node
(
node
->
shared_from_this
())
:
m_node
(
node
->
shared_from_this
())
,
m_index
(
index
)
,
m_index
(
index
)
{
{
...
@@ -627,7 +710,7 @@ namespace ngraph
...
@@ -627,7 +710,7 @@ namespace ngraph
/// \param index The index of the output.
/// \param index The index of the output.
///
///
/// TODO: Make a plan to deprecate this.
/// TODO: Make a plan to deprecate this.
Output
(
const
std
::
shared_ptr
<
Node
Type
>&
node
,
size_t
index
)
Output
(
const
std
::
shared_ptr
<
Node
>&
node
,
size_t
index
)
:
m_node
(
node
)
:
m_node
(
node
)
,
m_index
(
index
)
,
m_index
(
index
)
{
{
...
@@ -645,17 +728,13 @@ namespace ngraph
...
@@ -645,17 +728,13 @@ namespace ngraph
Output
()
=
default
;
Output
()
=
default
;
/// This output position for a different node
/// This output position for a different node
Output
<
NodeType
>
for_node
(
const
std
::
shared_ptr
<
NodeType
>&
node
)
Output
<
Node
>
for_node
(
const
std
::
shared_ptr
<
Node
>&
node
)
{
return
Output
(
node
,
m_index
);
}
{
return
Output
(
node
,
m_index
);
}
/// \return A pointer to the node referred to by this output handle.
/// \return A pointer to the node referred to by this output handle.
Node
Type
*
get_node
()
const
{
return
m_node
.
get
();
}
Node
*
get_node
()
const
{
return
m_node
.
get
();
}
/// \return A `shared_ptr` to the node referred to by this output handle.
/// \return A `shared_ptr` to the node referred to by this output handle.
///
///
/// TODO: Make a plan to deprecate this.
/// TODO: Make a plan to deprecate this.
std
::
shared_ptr
<
Node
Type
>
get_node_shared_ptr
()
const
{
return
m_node
;
}
std
::
shared_ptr
<
Node
>
get_node_shared_ptr
()
const
{
return
m_node
;
}
/// \return A useable shared pointer to this output. If index 0, the node,
/// \return A useable shared pointer to this output. If index 0, the node,
/// otherwise find or create a GOE.
/// otherwise find or create a GOE.
std
::
shared_ptr
<
Node
>
as_single_output_node
(
bool
for_get_output_element
=
true
)
const
std
::
shared_ptr
<
Node
>
as_single_output_node
(
bool
for_get_output_element
=
true
)
const
...
@@ -715,12 +794,105 @@ namespace ngraph
...
@@ -715,12 +794,105 @@ namespace ngraph
bool
operator
<=
(
const
Output
&
other
)
const
{
return
!
(
*
this
>
other
);
}
bool
operator
<=
(
const
Output
&
other
)
const
{
return
!
(
*
this
>
other
);
}
bool
operator
>=
(
const
Output
&
other
)
const
{
return
!
(
*
this
<
other
);
}
bool
operator
>=
(
const
Output
&
other
)
const
{
return
!
(
*
this
<
other
);
}
private
:
private
:
std
::
shared_ptr
<
Node
Type
>
m_node
;
std
::
shared_ptr
<
Node
>
m_node
;
size_t
m_index
{
0
};
size_t
m_index
{
0
};
};
};
template
class
NGRAPH_API
Input
<
Node
>
;
template
<>
template
class
NGRAPH_API
Output
<
Node
>
;
class
NGRAPH_API
Output
<
const
Node
>
{
public
:
/// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
Output
(
const
Node
*
node
,
size_t
index
)
:
m_node
(
node
->
shared_from_this
())
,
m_index
(
index
)
{
}
/// \brief Constructs a Output.
/// \param node A `shared_ptr` to the node for the output handle.
/// \param index The index of the output.
///
/// TODO: Make a plan to deprecate this.
Output
(
const
std
::
shared_ptr
<
const
Node
>&
node
,
size_t
index
)
:
m_node
(
node
)
,
m_index
(
index
)
{
}
/// \brief Constructs a Output, referencing the zeroth output of the node.
/// \param node A `shared_ptr` to the node for the output handle.
template
<
typename
T
>
Output
(
const
std
::
shared_ptr
<
T
>&
node
)
:
Output
(
node
,
0
)
{
}
/// A null output
Output
()
=
default
;
/// This output position for a different node
Output
<
const
Node
>
for_node
(
const
std
::
shared_ptr
<
const
Node
>&
node
)
{
return
Output
(
node
,
m_index
);
}
/// \return A pointer to the node referred to by this output handle.
const
Node
*
get_node
()
const
{
return
m_node
.
get
();
}
/// \return A `shared_ptr` to the node referred to by this output handle.
///
/// TODO: Make a plan to deprecate this.
std
::
shared_ptr
<
const
Node
>
get_node_shared_ptr
()
const
{
return
m_node
;
}
/// \return The index of the output referred to by this output handle.
size_t
get_index
()
const
{
return
m_index
;
}
/// \return A reference to the tensor descriptor for this output.
descriptor
::
Tensor
&
get_tensor
()
const
{
return
m_node
->
m_outputs
.
at
(
m_index
).
get_tensor
();
}
/// \return A shared point to the tensor ptr for this output.
std
::
shared_ptr
<
descriptor
::
Tensor
>
get_tensor_ptr
()
const
{
return
m_node
->
m_outputs
.
at
(
m_index
).
get_tensor_ptr
();
}
/// \return The element type of the output referred to by this output handle.
const
element
::
Type
&
get_element_type
()
const
{
return
m_node
->
get_output_element_type
(
m_index
);
}
/// \return The shape of the output referred to by this output handle.
const
Shape
&
get_shape
()
const
{
return
m_node
->
get_output_shape
(
m_index
);
}
/// \return The partial shape of the output referred to by this output handle.
const
PartialShape
&
get_partial_shape
()
const
{
return
m_node
->
get_output_partial_shape
(
m_index
);
}
/// \return A set containing handles for all inputs targeted by the output referenced by
/// this output handle.
std
::
set
<
Input
<
Node
>>
get_target_inputs
()
const
;
bool
operator
==
(
const
Output
&
other
)
const
{
return
m_node
==
other
.
m_node
&&
m_index
==
other
.
m_index
;
}
bool
operator
!=
(
const
Output
&
other
)
const
{
return
!
(
*
this
==
other
);
}
bool
operator
<
(
const
Output
&
other
)
const
{
return
m_node
<
other
.
m_node
||
(
m_node
==
other
.
m_node
&&
m_index
<
other
.
m_index
);
}
bool
operator
>
(
const
Output
&
other
)
const
{
return
m_node
>
other
.
m_node
||
(
m_node
==
other
.
m_node
&&
m_index
>
other
.
m_index
);
}
bool
operator
<=
(
const
Output
&
other
)
const
{
return
!
(
*
this
>
other
);
}
bool
operator
>=
(
const
Output
&
other
)
const
{
return
!
(
*
this
<
other
);
}
private
:
std
::
shared_ptr
<
const
Node
>
m_node
;
size_t
m_index
{
0
};
};
inline
Input
<
Node
>
Node
::
input
(
size_t
input_index
)
inline
Input
<
Node
>
Node
::
input
(
size_t
input_index
)
{
{
...
@@ -767,22 +939,25 @@ namespace ngraph
...
@@ -767,22 +939,25 @@ namespace ngraph
return
Output
<
const
Node
>
(
this
,
output_index
);
return
Output
<
const
Node
>
(
this
,
output_index
);
}
}
template
<
typename
NodeType
>
inline
Output
<
Node
>
Input
<
Node
>::
get_source_output
()
const
Output
<
Node
>
Input
<
NodeType
>::
get_source_output
()
const
{
{
auto
&
output_descriptor
=
m_node
->
m_inputs
.
at
(
m_index
).
get_output
();
auto
&
output_descriptor
=
m_node
->
m_inputs
.
at
(
m_index
).
get_output
();
return
Output
<
Node
>
(
output_descriptor
.
get_node
(),
output_descriptor
.
get_index
());
return
Output
<
Node
>
(
output_descriptor
.
get_node
(),
output_descriptor
.
get_index
());
}
}
template
<
typename
NodeType
>
inline
Output
<
Node
>
Input
<
const
Node
>::
get_source_output
()
const
void
Input
<
NodeType
>::
replace_source_output
(
const
Output
<
Node
>&
new_source_output
)
const
{
auto
&
output_descriptor
=
m_node
->
m_inputs
.
at
(
m_index
).
get_output
();
return
Output
<
Node
>
(
output_descriptor
.
get_node
(),
output_descriptor
.
get_index
());
}
inline
void
Input
<
Node
>::
replace_source_output
(
const
Output
<
Node
>&
new_source_output
)
const
{
{
m_node
->
m_inputs
.
at
(
m_index
).
replace_output
(
new_source_output
.
get_node_shared_ptr
(),
m_node
->
m_inputs
.
at
(
m_index
).
replace_output
(
new_source_output
.
get_node_shared_ptr
(),
new_source_output
.
get_index
());
new_source_output
.
get_index
());
}
}
template
<
typename
NodeType
>
inline
std
::
set
<
Input
<
Node
>>
Output
<
Node
>::
get_target_inputs
()
const
std
::
set
<
Input
<
Node
>>
Output
<
NodeType
>::
get_target_inputs
()
const
{
{
std
::
set
<
Input
<
Node
>>
result
;
std
::
set
<
Input
<
Node
>>
result
;
...
@@ -794,8 +969,19 @@ namespace ngraph
...
@@ -794,8 +969,19 @@ namespace ngraph
return
result
;
return
result
;
}
}
template
<
typename
NodeType
>
inline
std
::
set
<
Input
<
Node
>>
Output
<
const
Node
>::
get_target_inputs
()
const
void
Output
<
NodeType
>::
remove_target_input
(
const
Input
<
Node
>&
target_input
)
const
{
std
::
set
<
Input
<
Node
>>
result
;
for
(
auto
&
input
:
m_node
->
m_outputs
.
at
(
m_index
).
get_inputs
())
{
result
.
emplace
(
input
->
get_raw_pointer_node
(),
input
->
get_index
());
}
return
result
;
}
inline
void
Output
<
Node
>::
remove_target_input
(
const
Input
<
Node
>&
target_input
)
const
{
{
m_node
->
m_outputs
.
at
(
m_index
).
remove_input
(
m_node
->
m_outputs
.
at
(
m_index
).
remove_input
(
&
(
target_input
.
get_node
()
->
m_inputs
.
at
(
target_input
.
get_index
())));
&
(
target_input
.
get_node
()
->
m_inputs
.
at
(
target_input
.
get_index
())));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment