Unverified Commit 2f44b758 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Move Input/Output from node.?pp to separate files (#4305)

* Start

* Move Input

* Move Output

* Move ostream
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 6d299dab
......@@ -98,6 +98,10 @@ set (SRC
ngraph.cpp
ngraph.hpp
ngraph_visibility.hpp
node_input.cpp
node_input.hpp
node_output.cpp
node_output.hpp
node.cpp
node.hpp
op/abs.cpp
......
......@@ -991,49 +991,6 @@ bool Node::is_dynamic() const
return false;
}
namespace ngraph
{
std::ostream& operator<<(std::ostream& out, const Output<Node>& output)
{
return output.get_node()->write_description(out, 0) << "[" << output.get_index()
<< "]:" << output.get_element_type()
<< output.get_partial_shape();
}
std::ostream& operator<<(std::ostream& out, const Output<const Node>& output)
{
return output.get_node()->write_description(out, 0) << "[" << output.get_index()
<< "]:" << output.get_element_type()
<< output.get_partial_shape();
}
std::ostream& operator<<(std::ostream& out, const Input<Node>& input)
{
return input.get_node()->write_description(out, 0) << ".input(" << input.get_index()
<< "):" << input.get_element_type()
<< input.get_partial_shape();
}
std::ostream& operator<<(std::ostream& out, const Input<const Node>& input)
{
return input.get_node()->write_description(out, 0) << ".input(" << input.get_index()
<< "):" << input.get_element_type()
<< input.get_partial_shape();
}
void Output<Node>::replace(const Output<Node>& replacement)
{
for (auto& input : get_target_inputs())
{
// GOEs are used as handles in passes
if (!is_type<op::GetOutputElement>(input.get_node()))
{
input.replace_source_output(replacement);
}
}
}
}
Input<Node> Node::input(size_t input_index)
{
if (input_index >= m_inputs.size())
......
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/node_input.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
Input<Node>::Input(Node* node, size_t index)
: m_node(node)
, m_index(index)
{
}
Node* Input<Node>::get_node() const { return m_node; }
size_t Input<Node>::get_index() const { return m_index; }
const element::Type& Input<Node>::get_element_type() const
{
return m_node->get_input_element_type(m_index);
}
const Shape& Input<Node>::get_shape() const { return m_node->get_input_shape(m_index); }
const PartialShape& Input<Node>::get_partial_shape() const
{
return m_node->get_input_partial_shape(m_index);
}
Output<Node> Input<Node>::get_source_output() const
{
auto& output_descriptor = m_node->m_inputs.at(m_index).get_output();
return Output<Node>(output_descriptor.get_node(), output_descriptor.get_index());
}
descriptor::Tensor& Input<Node>::get_tensor() const
{
return m_node->m_inputs.at(m_index).get_output().get_tensor();
}
std::shared_ptr<descriptor::Tensor> Input<Node>::get_tensor_ptr() const
{
return m_node->m_inputs.at(m_index).get_output().get_tensor_ptr();
}
bool Input<Node>::get_is_relevant_to_shapes() const
{
return m_node->m_inputs.at(m_index).get_is_relevant_to_shape();
}
bool Input<Node>::get_is_relevant_to_values() const
{
return m_node->m_inputs.at(m_index).get_is_relevant_to_value();
}
void Input<Node>::replace_source_output(const Output<Node>& new_source_output) const
{
m_node->m_inputs.at(m_index).replace_output(new_source_output.get_node_shared_ptr(),
new_source_output.get_index());
}
bool Input<Node>::operator==(const Input& other) const
{
return m_node == other.m_node && m_index == other.m_index;
}
bool Input<Node>::operator!=(const Input& other) const { return !(*this == other); }
bool Input<Node>::operator<(const Input& other) const
{
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
}
bool Input<Node>::operator>(const Input& other) const
{
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
}
bool Input<Node>::operator<=(const Input& other) const { return !(*this > other); }
bool Input<Node>::operator>=(const Input& other) const { return !(*this < other); }
Input<const Node>::Input(const Node* node, size_t index)
: m_node(node)
, m_index(index)
{
}
const Node* Input<const Node>::get_node() const { return m_node; }
size_t Input<const Node>::get_index() const { return m_index; }
const element::Type& Input<const Node>::get_element_type() const
{
return m_node->get_input_element_type(m_index);
}
const Shape& Input<const Node>::get_shape() const { return m_node->get_input_shape(m_index); }
const PartialShape& Input<const Node>::get_partial_shape() const
{
return m_node->get_input_partial_shape(m_index);
}
Output<Node> Input<const Node>::get_source_output() const
{
auto& output_descriptor = m_node->m_inputs.at(m_index).get_output();
return Output<Node>(output_descriptor.get_node(), output_descriptor.get_index());
}
descriptor::Tensor& Input<const Node>::get_tensor() const
{
return m_node->m_inputs.at(m_index).get_output().get_tensor();
}
std::shared_ptr<descriptor::Tensor> Input<const Node>::get_tensor_ptr() const
{
return m_node->m_inputs.at(m_index).get_output().get_tensor_ptr();
}
bool Input<const Node>::get_is_relevant_to_shapes() const
{
return m_node->m_inputs.at(m_index).get_is_relevant_to_shape();
}
bool Input<const Node>::get_is_relevant_to_values() const
{
return m_node->m_inputs.at(m_index).get_is_relevant_to_value();
}
bool Input<const Node>::operator==(const Input& other) const
{
return m_node == other.m_node && m_index == other.m_index;
}
bool Input<const Node>::operator!=(const Input& other) const { return !(*this == other); }
bool Input<const Node>::operator<(const Input& other) const
{
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
}
bool Input<const Node>::operator>(const Input& other) const
{
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
}
bool Input<const Node>::operator<=(const Input& other) const { return !(*this > other); }
bool Input<const Node>::operator>=(const Input& other) const { return !(*this < other); }
std::ostream& operator<<(std::ostream& out, const Input<Node>& input)
{
return input.get_node()->write_description(out, 0) << ".input(" << input.get_index()
<< "):" << input.get_element_type()
<< input.get_partial_shape();
}
std::ostream& operator<<(std::ostream& out, const Input<const Node>& input)
{
return input.get_node()->write_description(out, 0) << ".input(" << input.get_index()
<< "):" << input.get_element_type()
<< input.get_partial_shape();
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstring>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
class Node;
template <typename NodeType>
class Output;
template <typename NodeType>
class Input
{
};
/// \brief A handle for one of a node's inputs.
template <>
class Input<Node>
{
public:
/// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle.
/// \param index The index of the input.
Input(Node* node, size_t index);
/// \return A pointer to the node referenced by this input handle.
Node* get_node() const;
/// \return The index of the input referred to by this input handle.
size_t get_index() const;
/// \return The element type of the input referred to by this input handle.
const element::Type& get_element_type() const;
/// \return The shape of the input referred to by this input handle.
const Shape& get_shape() const;
/// \return The partial shape of the input referred to by this input handle.
const PartialShape& get_partial_shape() const;
/// \return A handle to the output that is connected to this input.
Output<Node> get_source_output() const;
/// \return A reference to the tensor descriptor for this input.
descriptor::Tensor& get_tensor() const;
/// \return A shared pointer to the tensor descriptor for this input.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return true if this input is relevant to its node's output shapes; else false.
bool get_is_relevant_to_shapes() const;
/// \return true if this input is relevant to its node's output values; else false.
bool get_is_relevant_to_values() const;
/// \brief Replaces the source output of this input.
/// \param new_source_output A handle for the output that will replace this input's source.
void replace_source_output(const Output<Node>& new_source_output) const;
bool operator==(const Input& other) const;
bool operator!=(const Input& other) const;
bool operator<(const Input& other) const;
bool operator>(const Input& other) const;
bool operator<=(const Input& other) const;
bool operator>=(const Input& other) const;
private:
Node* const m_node;
const size_t m_index;
};
/// \brief A handle for one of a node's inputs.
template <>
class NGRAPH_API Input<const Node>
{
public:
/// \brief Constructs a Input.
/// \param node Pointer to the node for the input handle.
/// \param index The index of the input.
Input(const Node* node, size_t index);
/// \return A pointer to the node referenced by this input handle.
const Node* get_node() const;
/// \return The index of the input referred to by this input handle.
size_t get_index() const;
/// \return The element type of the input referred to by this input handle.
const element::Type& get_element_type() const;
/// \return The shape of the input referred to by this input handle.
const Shape& get_shape() const;
/// \return The partial shape of the input referred to by this input handle.
const PartialShape& get_partial_shape() const;
/// \return A handle to the output that is connected to this input.
Output<Node> get_source_output() const;
/// \return A reference to the tensor descriptor for this input.
descriptor::Tensor& get_tensor() const;
/// \return A shared pointer to the tensor descriptor for this input.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return true if this input is relevant to its node's output shapes; else false.
bool get_is_relevant_to_shapes() const;
/// \return true if this input is relevant to its node's output values; else false.
bool get_is_relevant_to_values() const;
bool operator==(const Input& other) const;
bool operator!=(const Input& other) const;
bool operator<(const Input& other) const;
bool operator>(const Input& other) const;
bool operator<=(const Input& other) const;
bool operator>=(const Input& other) const;
private:
const Node* const m_node;
const size_t m_index;
};
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<Node>& input);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<const Node>& input);
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/node_output.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/get_output_element.hpp"
namespace ngraph
{
Output<Node>::Output(Node* node, size_t index)
: m_node(node->shared_from_this())
, m_index(index)
{
}
Output<Node>::Output(const std::shared_ptr<Node>& node, size_t index)
: m_node(node)
, m_index(index)
{
}
void Output<Node>::reset()
{
m_node.reset();
m_index = 0;
}
Output<Node> Output<Node>::for_node(const std::shared_ptr<Node>& node)
{
return Output(node, m_index);
}
Node* Output<Node>::get_node() const { return m_node.get(); }
std::shared_ptr<Node> Output<Node>::get_node_shared_ptr() const { return m_node; }
std::shared_ptr<Node> Output<Node>::as_single_output_node(bool for_get_output_element) const
{
return m_node->get_output_as_single_output_node(m_index, for_get_output_element);
}
size_t Output<Node>::get_index() const { return m_index; }
descriptor::Tensor& Output<Node>::get_tensor() const
{
return m_node->m_outputs.at(m_index).get_tensor();
}
std::shared_ptr<descriptor::Tensor> Output<Node>::get_tensor_ptr() const
{
return m_node->m_outputs.at(m_index).get_tensor_ptr();
}
const element::Type& Output<Node>::get_element_type() const
{
return m_node->get_output_element_type(m_index);
}
const Shape& Output<Node>::get_shape() const { return m_node->get_output_shape(m_index); }
const PartialShape& Output<Node>::get_partial_shape() const
{
return m_node->get_output_partial_shape(m_index);
}
std::set<Input<Node>> Output<Node>::get_target_inputs() const
{
std::set<Input<Node>> result;
for (auto& input : m_node->m_outputs.at(m_index).get_inputs())
{
result.emplace(input->get_raw_pointer_node(), input->get_index());
}
return result;
}
void Output<Node>::remove_target_input(const Input<Node>& target_input) const
{
m_node->m_outputs.at(m_index).remove_input(
&(target_input.get_node()->m_inputs.at(target_input.get_index())));
}
void Output<Node>::replace(const Output<Node>& replacement)
{
for (auto& input : get_target_inputs())
{
// GOEs are used as handles in passes
if (!is_type<op::GetOutputElement>(input.get_node()))
{
input.replace_source_output(replacement);
}
}
}
bool Output<Node>::operator==(const Output& other) const
{
return m_node == other.m_node && m_index == other.m_index;
}
bool Output<Node>::operator!=(const Output& other) const { return !(*this == other); }
bool Output<Node>::operator<(const Output& other) const
{
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
}
bool Output<Node>::operator>(const Output& other) const
{
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
}
bool Output<Node>::operator<=(const Output& other) const { return !(*this > other); }
bool Output<Node>::operator>=(const Output& other) const { return !(*this < other); }
Output<const Node>::Output(const Node* node, size_t index)
: m_node(node->shared_from_this())
, m_index(index)
{
}
Output<const Node>::Output(const std::shared_ptr<const Node>& node, size_t index)
: m_node(node)
, m_index(index)
{
}
void Output<const Node>::reset()
{
m_node.reset();
m_index = 0;
}
Output<const Node> Output<const Node>::for_node(const std::shared_ptr<const Node>& node)
{
return Output(node, m_index);
}
const Node* Output<const Node>::get_node() const { return m_node.get(); }
std::shared_ptr<const Node> Output<const Node>::get_node_shared_ptr() const { return m_node; }
size_t Output<const Node>::get_index() const { return m_index; }
descriptor::Tensor& Output<const Node>::get_tensor() const
{
return m_node->m_outputs.at(m_index).get_tensor();
}
std::shared_ptr<descriptor::Tensor> Output<const Node>::get_tensor_ptr() const
{
return m_node->m_outputs.at(m_index).get_tensor_ptr();
}
const element::Type& Output<const Node>::get_element_type() const
{
return m_node->get_output_element_type(m_index);
}
const Shape& Output<const Node>::get_shape() const { return m_node->get_output_shape(m_index); }
const PartialShape& Output<const Node>::get_partial_shape() const
{
return m_node->get_output_partial_shape(m_index);
}
std::set<Input<Node>> Output<const Node>::get_target_inputs() const
{
std::set<Input<Node>> result;
for (auto& input : m_node->m_outputs.at(m_index).get_inputs())
{
result.emplace(input->get_raw_pointer_node(), input->get_index());
}
return result;
}
bool Output<const Node>::operator==(const Output& other) const
{
return m_node == other.m_node && m_index == other.m_index;
}
bool Output<const Node>::operator!=(const Output& other) const { return !(*this == other); }
bool Output<const Node>::operator<(const Output& other) const
{
return m_node < other.m_node || (m_node == other.m_node && m_index < other.m_index);
}
bool Output<const Node>::operator>(const Output& other) const
{
return m_node > other.m_node || (m_node == other.m_node && m_index > other.m_index);
}
bool Output<const Node>::operator<=(const Output& other) const { return !(*this > other); }
bool Output<const Node>::operator>=(const Output& other) const { return !(*this < other); }
std::ostream& operator<<(std::ostream& out, const Output<Node>& output)
{
return output.get_node()->write_description(out, 0) << "[" << output.get_index()
<< "]:" << output.get_element_type()
<< output.get_partial_shape();
}
std::ostream& operator<<(std::ostream& out, const Output<const Node>& output)
{
return output.get_node()->write_description(out, 0) << "[" << output.get_index()
<< "]:" << output.get_element_type()
<< output.get_partial_shape();
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstring>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
class Node;
template <typename NodeType>
class Input;
template <typename NodeType>
class Output
{
};
/// \brief A handle for one of a node's outputs.
template <>
class NGRAPH_API Output<Node>
{
public:
/// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
Output(Node* node, size_t index);
/// \brief Constructs a Output.
/// \param node A `shared_ptr` to the node for the output handle.
/// \param index The index of the output.
///
/// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<Node>& node, size_t index);
/// \brief Constructs a Output, referencing the zeroth output of the node.
/// \param node A `shared_ptr` to the node for the output handle.
template <typename T>
Output(const std::shared_ptr<T>& node)
: Output(node, 0)
{
}
/// A null output
Output() = default;
void reset();
/// This output position for a different node
Output<Node> for_node(const std::shared_ptr<Node>& node);
/// \return A pointer to the node referred to by this output handle.
Node* get_node() const;
/// \return A `shared_ptr` to the node referred to by this output handle.
///
/// TODO: Make a plan to deprecate this.
std::shared_ptr<Node> get_node_shared_ptr() const;
/// \return A useable shared pointer to this output. If index 0, the node,
/// otherwise find or create a GOE.
std::shared_ptr<Node> as_single_output_node(bool for_get_output_element = true) const;
/// \return The index of the output referred to by this output handle.
size_t get_index() const;
/// \return A reference to the tensor descriptor for this output.
descriptor::Tensor& get_tensor() const;
/// \return A shared point to the tensor ptr for this output.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return The element type of the output referred to by this output handle.
const element::Type& get_element_type() const;
/// \return The shape of the output referred to by this output handle.
const Shape& get_shape() const;
/// \return The partial shape of the output referred to by this output handle.
const PartialShape& get_partial_shape() const;
/// \return A set containing handles for all inputs targeted by the output referenced by
/// this output handle.
std::set<Input<Node>> get_target_inputs() const;
/// \brief Removes a target input from the output referenced by this output handle.
/// \param target_input The target input to remove.
///
// TODO(amprocte): Investigate whether this really ought to be public.
void remove_target_input(const Input<Node>& target_input) const;
/// \brief Replace all users of this value with replacement
void replace(const Output<Node>& replacement);
bool operator==(const Output& other) const;
bool operator!=(const Output& other) const;
bool operator<(const Output& other) const;
bool operator>(const Output& other) const;
bool operator<=(const Output& other) const;
bool operator>=(const Output& other) const;
private:
std::shared_ptr<Node> m_node;
size_t m_index{0};
};
template <>
class NGRAPH_API Output<const Node>
{
public:
/// \brief Constructs a Output.
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
Output(const Node* node, size_t index);
/// \brief Constructs a Output.
/// \param node A `shared_ptr` to the node for the output handle.
/// \param index The index of the output.
///
/// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<const Node>& node, size_t index);
/// \brief Constructs a Output, referencing the zeroth output of the node.
/// \param node A `shared_ptr` to the node for the output handle.
template <typename T>
Output(const std::shared_ptr<T>& node)
: Output(node, 0)
{
}
/// A null output
Output() = default;
void reset();
/// This output position for a different node
Output<const Node> for_node(const std::shared_ptr<const Node>& node);
/// \return A pointer to the node referred to by this output handle.
const Node* get_node() const;
/// \return A `shared_ptr` to the node referred to by this output handle.
///
/// TODO: Make a plan to deprecate this.
std::shared_ptr<const Node> get_node_shared_ptr() const;
/// \return The index of the output referred to by this output handle.
size_t get_index() const;
/// \return A reference to the tensor descriptor for this output.
descriptor::Tensor& get_tensor() const;
/// \return A shared point to the tensor ptr for this output.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
/// \return The element type of the output referred to by this output handle.
const element::Type& get_element_type() const;
/// \return The shape of the output referred to by this output handle.
const Shape& get_shape() const;
/// \return The partial shape of the output referred to by this output handle.
const PartialShape& get_partial_shape() const;
/// \return A set containing handles for all inputs targeted by the output referenced by
/// this output handle.
std::set<Input<Node>> get_target_inputs() const;
bool operator==(const Output& other) const;
bool operator!=(const Output& other) const;
bool operator<(const Output& other) const;
bool operator>(const Output& other) const;
bool operator<=(const Output& other) const;
bool operator>=(const Output& other) const;
private:
std::shared_ptr<const Node> m_node;
size_t m_index{0};
};
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<Node>& output);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<const Node>& output);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment