Unverified Commit bd50f338 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Move annotations from Op to Node (#3738)

parent a4bf1c43
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/placement.hpp" #include "ngraph/placement.hpp"
#include "ngraph/strides.hpp" #include "ngraph/strides.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
...@@ -485,6 +486,15 @@ namespace ngraph ...@@ -485,6 +486,15 @@ namespace ngraph
/// \throw std::out_of_range if the node does not have at least `output_index+1` outputs. /// \throw std::out_of_range if the node does not have at least `output_index+1` outputs.
Output<const Node> output(size_t output_index) const; Output<const Node> output(size_t output_index) const;
void set_op_annotations(std::shared_ptr<ngraph::op::util::OpAnnotations> op_annotations)
{
m_op_annotations = op_annotations;
}
std::shared_ptr<ngraph::op::util::OpAnnotations> get_op_annotations() const
{
return m_op_annotations;
}
private: private:
descriptor::Input& get_input_descriptor(size_t position); descriptor::Input& get_input_descriptor(size_t position);
descriptor::Output& get_output_descriptor(size_t position); descriptor::Output& get_output_descriptor(size_t position);
...@@ -504,6 +514,7 @@ namespace ngraph ...@@ -504,6 +514,7 @@ namespace ngraph
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map; std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
Placement m_placement = Placement::DEFAULT; Placement m_placement = Placement::DEFAULT;
size_t m_placement_index = placement_invalid; size_t m_placement_index = placement_invalid;
std::shared_ptr<ngraph::op::util::OpAnnotations> m_op_annotations;
}; };
/// \brief A handle for one of a node's inputs. /// \brief A handle for one of a node's inputs.
......
...@@ -18,9 +18,7 @@ ...@@ -18,9 +18,7 @@
#include <string> #include <string>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/op_annotations.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -30,15 +28,6 @@ namespace ngraph ...@@ -30,15 +28,6 @@ namespace ngraph
class Op : public Node class Op : public Node
{ {
public: public:
void set_op_annotations(std::shared_ptr<ngraph::op::util::OpAnnotations> op_annotations)
{
m_op_annotations = op_annotations;
}
std::shared_ptr<ngraph::op::util::OpAnnotations> get_op_annotations() const
{
return m_op_annotations;
}
virtual bool is_op() const override { return true; } virtual bool is_op() const override { return true; }
protected: protected:
Op() Op()
...@@ -48,9 +37,6 @@ namespace ngraph ...@@ -48,9 +37,6 @@ namespace ngraph
Op(const NodeVector& arguments); Op(const NodeVector& arguments);
Op(const OutputVector& arguments); Op(const OutputVector& arguments);
Op(const std::string& node_type, const NodeVector& arguments); Op(const std::string& node_type, const NodeVector& arguments);
private:
std::shared_ptr<ngraph::op::util::OpAnnotations> m_op_annotations;
}; };
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment