Commit e1ad1900 authored by Scott Cyphers's avatar Scott Cyphers Committed by Sang Ik Lee

Add additional exports (#4006)

* Add exports

* Work-around windows issues

* windows

* Avoid vectors
parent b231ccc3
...@@ -33,7 +33,7 @@ namespace ngraph ...@@ -33,7 +33,7 @@ namespace ngraph
/// Attributes are the values set when building a graph which are not /// Attributes are the values set when building a graph which are not
/// computed as the graph executes. Values computed from the graph topology and attributes /// computed as the graph executes. Values computed from the graph topology and attributes
/// during compilation are not attributes. /// during compilation are not attributes.
class AttributeVisitor class NGRAPH_API AttributeVisitor
{ {
public: public:
virtual ~AttributeVisitor() {} virtual ~AttributeVisitor() {}
......
...@@ -531,22 +531,32 @@ namespace ngraph ...@@ -531,22 +531,32 @@ namespace ngraph
std::map<std::string, std::shared_ptr<Variant>> m_rt_info; std::map<std::string, std::shared_ptr<Variant>> m_rt_info;
}; };
/// \brief A handle for one of a node's inputs.
template <typename NodeType> template <typename NodeType>
class Input class Input
{ {
};
template <typename NodeType>
class Output
{
};
/// \brief A handle for one of a node's inputs.
template <>
class Input<Node>
{
public: public:
/// \brief Constructs a Input. /// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle. /// \param node Pointer to the node for the input handle.
/// \param index The index of the input. /// \param index The index of the input.
Input(NodeType* node, size_t index) Input(Node* node, size_t index)
: m_node(node) : m_node(node)
, m_index(index) , m_index(index)
{ {
} }
/// \return A pointer to the node referenced by this input handle. /// \return A pointer to the node referenced by this input handle.
NodeType* get_node() const { return m_node; } Node* get_node() const { return m_node; }
/// \return The index of the input referred to by this input handle. /// \return The index of the input referred to by this input handle.
size_t get_index() const { return m_index; } size_t get_index() const { return m_index; }
/// \return The element type of the input referred to by this input handle. /// \return The element type of the input referred to by this input handle.
...@@ -604,19 +614,92 @@ namespace ngraph ...@@ -604,19 +614,92 @@ namespace ngraph
bool operator<=(const Input& other) const { return !(*this > other); } bool operator<=(const Input& other) const { return !(*this > other); }
bool operator>=(const Input& other) const { return !(*this < other); } bool operator>=(const Input& other) const { return !(*this < other); }
private: private:
NodeType* const m_node; Node* const m_node;
const size_t m_index;
};
/// \brief A handle for one of a node's inputs.
template <>
class NGRAPH_API Input<const Node>
{
public:
/// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle.
/// \param index The index of the input.
Input(const Node* node, size_t index)
: m_node(node)
, m_index(index)
{
}
/// \return A pointer to the node referenced by this input handle.
const Node* get_node() const { return m_node; }
/// \return The index of the input referred to by this input handle.
size_t get_index() const { return m_index; }
/// \return The element type of the input referred to by this input handle.
const element::Type& get_element_type() const
{
return m_node->get_input_element_type(m_index);
}
/// \return The shape of the input referred to by this input handle.
const Shape& get_shape() const { return m_node->get_input_shape(m_index); }
/// \return The partial shape of the input referred to by this input handle.
const PartialShape& get_partial_shape() const
{
return m_node->get_input_partial_shape(m_index);
}
/// \return A handle to the output that is connected to this input.
Output<Node> get_source_output() const;
/// \return A reference to the tensor descriptor for this input.
descriptor::Tensor& get_tensor() const
{
return m_node->m_inputs.at(m_index).get_output().get_tensor();
}
/// \return A shared pointer to the tensor descriptor for this input.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const
{
return m_node->m_inputs.at(m_index).get_output().get_tensor_ptr();
}
/// \return true if this input is relevant to its node's output shapes; else false.
bool get_is_relevant_to_shapes() const
{
return m_node->m_inputs.at(m_index).get_is_relevant_to_shape();
}
/// \return true if this input is relevant to its node's output values; else false.
bool get_is_relevant_to_values() const
{
return m_node->m_inputs.at(m_index).get_is_relevant_to_value();
}
bool operator==(const Input& other) const
{
return m_node == other.m_node && m_index == other.m_index;
}
bool operator!=(const Input& other) const { return !(*this == other); }
bool operator<(const Input& other) const
{
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
}
bool operator>(const Input& other) const
{
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
}
bool operator<=(const Input& other) const { return !(*this > other); }
bool operator>=(const Input& other) const { return !(*this < other); }
private:
const Node* const m_node;
const size_t m_index; const size_t m_index;
}; };
/// \brief A handle for one of a node's outputs. /// \brief A handle for one of a node's outputs.
template <typename NodeType = Node> template <>
class Output class NGRAPH_API Output<Node>
{ {
public: public:
/// \brief Constructs a Output. /// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle. /// \param node A pointer to the node for the output handle.
/// \param index The index of the output. /// \param index The index of the output.
Output(NodeType* node, size_t index) Output(Node* node, size_t index)
: m_node(node->shared_from_this()) : m_node(node->shared_from_this())
, m_index(index) , m_index(index)
{ {
...@@ -627,7 +710,7 @@ namespace ngraph ...@@ -627,7 +710,7 @@ namespace ngraph
/// \param index The index of the output. /// \param index The index of the output.
/// ///
/// TODO: Make a plan to deprecate this. /// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<NodeType>& node, size_t index) Output(const std::shared_ptr<Node>& node, size_t index)
: m_node(node) : m_node(node)
, m_index(index) , m_index(index)
{ {
...@@ -645,17 +728,13 @@ namespace ngraph ...@@ -645,17 +728,13 @@ namespace ngraph
Output() = default; Output() = default;
/// This output position for a different node /// This output position for a different node
Output<NodeType> for_node(const std::shared_ptr<NodeType>& node) Output<Node> for_node(const std::shared_ptr<Node>& node) { return Output(node, m_index); }
{
return Output(node, m_index);
}
/// \return A pointer to the node referred to by this output handle. /// \return A pointer to the node referred to by this output handle.
NodeType* get_node() const { return m_node.get(); } Node* get_node() const { return m_node.get(); }
/// \return A `shared_ptr` to the node referred to by this output handle. /// \return A `shared_ptr` to the node referred to by this output handle.
/// ///
/// TODO: Make a plan to deprecate this. /// TODO: Make a plan to deprecate this.
std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node; } std::shared_ptr<Node> get_node_shared_ptr() const { return m_node; }
/// \return A useable shared pointer to this output. If index 0, the node, /// \return A useable shared pointer to this output. If index 0, the node,
/// otherwise find or create a GOE. /// otherwise find or create a GOE.
std::shared_ptr<Node> as_single_output_node(bool for_get_output_element = true) const std::shared_ptr<Node> as_single_output_node(bool for_get_output_element = true) const
...@@ -715,12 +794,105 @@ namespace ngraph ...@@ -715,12 +794,105 @@ namespace ngraph
bool operator<=(const Output& other) const { return !(*this > other); } bool operator<=(const Output& other) const { return !(*this > other); }
bool operator>=(const Output& other) const { return !(*this < other); } bool operator>=(const Output& other) const { return !(*this < other); }
private: private:
std::shared_ptr<NodeType> m_node; std::shared_ptr<Node> m_node;
size_t m_index{0}; size_t m_index{0};
}; };
template class NGRAPH_API Input<Node>; template <>
template class NGRAPH_API Output<Node>; class NGRAPH_API Output<const Node>
{
public:
/// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
Output(const Node* node, size_t index)
: m_node(node->shared_from_this())
, m_index(index)
{
}
/// \brief Constructs a Output.
/// \param node A `shared_ptr` to the node for the output handle.
/// \param index The index of the output.
///
/// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<const Node>& node, size_t index)
: m_node(node)
, m_index(index)
{
}
/// \brief Constructs a Output, referencing the zeroth output of the node.
/// \param node A `shared_ptr` to the node for the output handle.
template <typename T>
Output(const std::shared_ptr<T>& node)
: Output(node, 0)
{
}
/// A null output
Output() = default;
/// This output position for a different node
Output<const Node> for_node(const std::shared_ptr<const Node>& node)
{
return Output(node, m_index);
}
/// \return A pointer to the node referred to by this output handle.
const Node* get_node() const { return m_node.get(); }
/// \return A `shared_ptr` to the node referred to by this output handle.
///
/// TODO: Make a plan to deprecate this.
std::shared_ptr<const Node> get_node_shared_ptr() const { return m_node; }
/// \return The index of the output referred to by this output handle.
size_t get_index() const { return m_index; }
/// \return A reference to the tensor descriptor for this output.
descriptor::Tensor& get_tensor() const
{
return m_node->m_outputs.at(m_index).get_tensor();
}
/// \return A shared point to the tensor ptr for this output.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const
{
return m_node->m_outputs.at(m_index).get_tensor_ptr();
}
/// \return The element type of the output referred to by this output handle.
const element::Type& get_element_type() const
{
return m_node->get_output_element_type(m_index);
}
/// \return The shape of the output referred to by this output handle.
const Shape& get_shape() const { return m_node->get_output_shape(m_index); }
/// \return The partial shape of the output referred to by this output handle.
const PartialShape& get_partial_shape() const
{
return m_node->get_output_partial_shape(m_index);
}
/// \return A set containing handles for all inputs targeted by the output referenced by
/// this output handle.
std::set<Input<Node>> get_target_inputs() const;
bool operator==(const Output& other) const
{
return m_node == other.m_node && m_index == other.m_index;
}
bool operator!=(const Output& other) const { return !(*this == other); }
bool operator<(const Output& other) const
{
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
}
bool operator>(const Output& other) const
{
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
}
bool operator<=(const Output& other) const { return !(*this > other); }
bool operator>=(const Output& other) const { return !(*this < other); }
private:
std::shared_ptr<const Node> m_node;
size_t m_index{0};
};
inline Input<Node> Node::input(size_t input_index) inline Input<Node> Node::input(size_t input_index)
{ {
...@@ -767,22 +939,25 @@ namespace ngraph ...@@ -767,22 +939,25 @@ namespace ngraph
return Output<const Node>(this, output_index); return Output<const Node>(this, output_index);
} }
template <typename NodeType> inline Output<Node> Input<Node>::get_source_output() const
Output<Node> Input<NodeType>::get_source_output() const
{ {
auto& output_descriptor = m_node->m_inputs.at(m_index).get_output(); auto& output_descriptor = m_node->m_inputs.at(m_index).get_output();
return Output<Node>(output_descriptor.get_node(), output_descriptor.get_index()); return Output<Node>(output_descriptor.get_node(), output_descriptor.get_index());
} }
template <typename NodeType> inline Output<Node> Input<const Node>::get_source_output() const
void Input<NodeType>::replace_source_output(const Output<Node>& new_source_output) const {
auto& output_descriptor = m_node->m_inputs.at(m_index).get_output();
return Output<Node>(output_descriptor.get_node(), output_descriptor.get_index());
}
inline void Input<Node>::replace_source_output(const Output<Node>& new_source_output) const
{ {
m_node->m_inputs.at(m_index).replace_output(new_source_output.get_node_shared_ptr(), m_node->m_inputs.at(m_index).replace_output(new_source_output.get_node_shared_ptr(),
new_source_output.get_index()); new_source_output.get_index());
} }
template <typename NodeType> inline std::set<Input<Node>> Output<Node>::get_target_inputs() const
std::set<Input<Node>> Output<NodeType>::get_target_inputs() const
{ {
std::set<Input<Node>> result; std::set<Input<Node>> result;
...@@ -794,8 +969,19 @@ namespace ngraph ...@@ -794,8 +969,19 @@ namespace ngraph
return result; return result;
} }
template <typename NodeType> inline std::set<Input<Node>> Output<const Node>::get_target_inputs() const
void Output<NodeType>::remove_target_input(const Input<Node>& target_input) const {
std::set<Input<Node>> result;
for (auto& input : m_node->m_outputs.at(m_index).get_inputs())
{
result.emplace(input->get_raw_pointer_node(), input->get_index());
}
return result;
}
inline void Output<Node>::remove_target_input(const Input<Node>& target_input) const
{ {
m_node->m_outputs.at(m_index).remove_input( m_node->m_outputs.at(m_index).remove_input(
&(target_input.get_node()->m_inputs.at(target_input.get_index()))); &(target_input.get_node()->m_inputs.at(target_input.get_index())));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment