Commit e1ad1900 authored by Scott Cyphers's avatar Scott Cyphers Committed by Sang Ik Lee

Add additional exports (#4006)

* Add exports

* Work-around windows issues

* windows

* Avoid vectors
parent b231ccc3
......@@ -33,7 +33,7 @@ namespace ngraph
/// Attributes are the values set when building a graph which are not
/// computed as the graph executes. Values computed from the graph topology and attributes
/// during compilation are not attributes.
class AttributeVisitor
class NGRAPH_API AttributeVisitor
{
public:
virtual ~AttributeVisitor() {}
......
......@@ -531,22 +531,32 @@ namespace ngraph
std::map<std::string, std::shared_ptr<Variant>> m_rt_info;
};
/// \brief A handle for one of a node's inputs.
template <typename NodeType>
class Input
{
};
template <typename NodeType>
class Output
{
};
/// \brief A handle for one of a node's inputs.
template <>
class Input<Node>
{
public:
/// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle.
/// \param index The index of the input.
Input(NodeType* node, size_t index)
Input(Node* node, size_t index)
: m_node(node)
, m_index(index)
{
}
/// \return A pointer to the node referenced by this input handle.
NodeType* get_node() const { return m_node; }
Node* get_node() const { return m_node; }
/// \return The index of the input referred to by this input handle.
size_t get_index() const { return m_index; }
/// \return The element type of the input referred to by this input handle.
......@@ -604,19 +614,92 @@ namespace ngraph
bool operator<=(const Input& other) const { return !(*this > other); }
bool operator>=(const Input& other) const { return !(*this < other); }
private:
NodeType* const m_node;
Node* const m_node;
const size_t m_index;
};
/// \brief A handle for one of a node's inputs.
template <>
class NGRAPH_API Input<const Node>
{
public:
/// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle.
/// \param index The index of the input.
Input(const Node* node, size_t index)
: m_node(node)
, m_index(index)
{
}
/// \return A pointer to the node referenced by this input handle.
const Node* get_node() const { return m_node; }
/// \return The index of the input referred to by this input handle.
size_t get_index() const { return m_index; }
/// \return The element type of the input referred to by this input handle.
const element::Type& get_element_type() const
{
return m_node->get_input_element_type(m_index);
}
/// \return The shape of the input referred to by this input handle.
const Shape& get_shape() const { return m_node->get_input_shape(m_index); }
/// \return The partial shape of the input referred to by this input handle.
const PartialShape& get_partial_shape() const
{
return m_node->get_input_partial_shape(m_index);
}
/// \return A handle to the output that is connected to this input.
Output<Node> get_source_output() const;
/// \return A reference to the tensor descriptor for this input.
descriptor::Tensor& get_tensor() const
{
return m_node->m_inputs.at(m_index).get_output().get_tensor();
}
/// \return A shared pointer to the tensor descriptor for this input.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const
{
return m_node->m_inputs.at(m_index).get_output().get_tensor_ptr();
}
/// \return true if this input is relevant to its node's output shapes; else false.
bool get_is_relevant_to_shapes() const
{
return m_node->m_inputs.at(m_index).get_is_relevant_to_shape();
}
/// \return true if this input is relevant to its node's output values; else false.
bool get_is_relevant_to_values() const
{
return m_node->m_inputs.at(m_index).get_is_relevant_to_value();
}
bool operator==(const Input& other) const
{
return m_node == other.m_node && m_index == other.m_index;
}
bool operator!=(const Input& other) const { return !(*this == other); }
bool operator<(const Input& other) const
{
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
}
bool operator>(const Input& other) const
{
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
}
bool operator<=(const Input& other) const { return !(*this > other); }
bool operator>=(const Input& other) const { return !(*this < other); }
private:
const Node* const m_node;
const size_t m_index;
};
/// \brief A handle for one of a node's outputs.
template <typename NodeType = Node>
class Output
template <>
class NGRAPH_API Output<Node>
{
public:
/// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
Output(NodeType* node, size_t index)
Output(Node* node, size_t index)
: m_node(node->shared_from_this())
, m_index(index)
{
......@@ -627,7 +710,7 @@ namespace ngraph
/// \param index The index of the output.
///
/// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<NodeType>& node, size_t index)
Output(const std::shared_ptr<Node>& node, size_t index)
: m_node(node)
, m_index(index)
{
......@@ -645,17 +728,13 @@ namespace ngraph
Output() = default;
/// This output position for a different node
Output<NodeType> for_node(const std::shared_ptr<NodeType>& node)
{
return Output(node, m_index);
}
Output<Node> for_node(const std::shared_ptr<Node>& node) { return Output(node, m_index); }
/// \return A pointer to the node referred to by this output handle.
NodeType* get_node() const { return m_node.get(); }
Node* get_node() const { return m_node.get(); }
/// \return A `shared_ptr` to the node referred to by this output handle.
///
/// TODO: Make a plan to deprecate this.
std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node; }
std::shared_ptr<Node> get_node_shared_ptr() const { return m_node; }
/// \return A useable shared pointer to this output. If index 0, the node,
/// otherwise find or create a GOE.
std::shared_ptr<Node> as_single_output_node(bool for_get_output_element = true) const
......@@ -715,12 +794,105 @@ namespace ngraph
bool operator<=(const Output& other) const { return !(*this > other); }
bool operator>=(const Output& other) const { return !(*this < other); }
private:
std::shared_ptr<NodeType> m_node;
std::shared_ptr<Node> m_node;
size_t m_index{0};
};
template class NGRAPH_API Input<Node>;
template class NGRAPH_API Output<Node>;
template <>
class NGRAPH_API Output<const Node>
{
public:
/// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
Output(const Node* node, size_t index)
: m_node(node->shared_from_this())
, m_index(index)
{
}
/// \brief Constructs a Output.
/// \param node A `shared_ptr` to the node for the output handle.
/// \param index The index of the output.
///
/// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<const Node>& node, size_t index)
: m_node(node)
, m_index(index)
{
}
/// \brief Constructs a Output, referencing the zeroth output of the node.
/// \param node A `shared_ptr` to the node for the output handle.
template <typename T>
Output(const std::shared_ptr<T>& node)
: Output(node, 0)
{
}
/// A null output
Output() = default;
/// This output position for a different node
Output<const Node> for_node(const std::shared_ptr<const Node>& node)
{
return Output(node, m_index);
}
/// \return A pointer to the node referred to by this output handle.
const Node* get_node() const { return m_node.get(); }
/// \return A `shared_ptr` to the node referred to by this output handle.
///
/// TODO: Make a plan to deprecate this.
std::shared_ptr<const Node> get_node_shared_ptr() const { return m_node; }
/// \return The index of the output referred to by this output handle.
size_t get_index() const { return m_index; }
/// \return A reference to the tensor descriptor for this output.
descriptor::Tensor& get_tensor() const
{
return m_node->m_outputs.at(m_index).get_tensor();
}
/// \return A shared point to the tensor ptr for this output.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const
{
return m_node->m_outputs.at(m_index).get_tensor_ptr();
}
/// \return The element type of the output referred to by this output handle.
const element::Type& get_element_type() const
{
return m_node->get_output_element_type(m_index);
}
/// \return The shape of the output referred to by this output handle.
const Shape& get_shape() const { return m_node->get_output_shape(m_index); }
/// \return The partial shape of the output referred to by this output handle.
const PartialShape& get_partial_shape() const
{
return m_node->get_output_partial_shape(m_index);
}
/// \return A set containing handles for all inputs targeted by the output referenced by
/// this output handle.
std::set<Input<Node>> get_target_inputs() const;
bool operator==(const Output& other) const
{
return m_node == other.m_node && m_index == other.m_index;
}
bool operator!=(const Output& other) const { return !(*this == other); }
bool operator<(const Output& other) const
{
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
}
bool operator>(const Output& other) const
{
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
}
bool operator<=(const Output& other) const { return !(*this > other); }
bool operator>=(const Output& other) const { return !(*this < other); }
private:
std::shared_ptr<const Node> m_node;
size_t m_index{0};
};
inline Input<Node> Node::input(size_t input_index)
{
......@@ -767,22 +939,25 @@ namespace ngraph
return Output<const Node>(this, output_index);
}
template <typename NodeType>
Output<Node> Input<NodeType>::get_source_output() const
inline Output<Node> Input<Node>::get_source_output() const
{
auto& output_descriptor = m_node->m_inputs.at(m_index).get_output();
return Output<Node>(output_descriptor.get_node(), output_descriptor.get_index());
}
template <typename NodeType>
void Input<NodeType>::replace_source_output(const Output<Node>& new_source_output) const
inline Output<Node> Input<const Node>::get_source_output() const
{
auto& output_descriptor = m_node->m_inputs.at(m_index).get_output();
return Output<Node>(output_descriptor.get_node(), output_descriptor.get_index());
}
inline void Input<Node>::replace_source_output(const Output<Node>& new_source_output) const
{
m_node->m_inputs.at(m_index).replace_output(new_source_output.get_node_shared_ptr(),
new_source_output.get_index());
}
template <typename NodeType>
std::set<Input<Node>> Output<NodeType>::get_target_inputs() const
inline std::set<Input<Node>> Output<Node>::get_target_inputs() const
{
std::set<Input<Node>> result;
......@@ -794,8 +969,19 @@ namespace ngraph
return result;
}
template <typename NodeType>
void Output<NodeType>::remove_target_input(const Input<Node>& target_input) const
inline std::set<Input<Node>> Output<const Node>::get_target_inputs() const
{
std::set<Input<Node>> result;
for (auto& input : m_node->m_outputs.at(m_index).get_inputs())
{
result.emplace(input->get_raw_pointer_node(), input->get_index());
}
return result;
}
inline void Output<Node>::remove_target_input(const Input<Node>& target_input) const
{
m_node->m_outputs.at(m_index).remove_input(
&(target_input.get_node()->m_inputs.at(target_input.get_index())));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment