Unverified Commit 09b7e413 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Move some ops into v0 (#4138)

* Move some ops into v0

* namespace

* Make comments pretty

* Make comments pretty

* Merge fix
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent fbd99f34
......@@ -41,7 +41,10 @@ namespace ngraph
namespace op
{
class Parameter;
namespace v0
{
class Parameter;
}
}
void traverse_nodes(const std::shared_ptr<const Function> p,
......@@ -240,8 +243,8 @@ namespace ngraph
/// `body_replacement_map`, behavior is unspecified.
void replace_nodes(
const std::shared_ptr<Function>& f,
const std::unordered_map<std::shared_ptr<op::Parameter>, std::shared_ptr<op::Parameter>>&
parameter_replacement_map,
const std::unordered_map<std::shared_ptr<op::v0::Parameter>,
std::shared_ptr<op::v0::Parameter>>& parameter_replacement_map,
const std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>>&
body_replacement_map);
......@@ -400,7 +403,7 @@ namespace ngraph
// Assert that nodes in the function is colocated and return that placement
Placement get_colocated_function_placement(std::shared_ptr<Function> func);
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::Parameter>>
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::v0::Parameter>>
insert_result_parameter_split(const std::shared_ptr<Node>& src_node,
const std::shared_ptr<Node>& dst_node);
......
......@@ -60,11 +60,14 @@ namespace ngraph
namespace op
{
struct AutoBroadcastSpec;
class Constant;
class Result;
namespace v0
{
class Result;
}
} // namespace op
using ResultVector = std::vector<std::shared_ptr<op::Result>>;
using ResultVector = std::vector<std::shared_ptr<op::v0::Result>>;
namespace autodiff
{
......
......@@ -422,13 +422,16 @@ namespace ngraph
{
namespace op
{
template <>
void Constant::write_to_buffer<string>(const element::Type& /* target_type */,
const Shape& /* target_shape */,
const vector<string>& /* source */,
void* /* target */,
size_t /* target_element_count */)
namespace v0
{
template <>
void Constant::write_to_buffer<string>(const element::Type& /* target_type */,
const Shape& /* target_shape */,
const vector<string>& /* source */,
void* /* target */,
size_t /* target_element_count */)
{
}
}
}
}
This diff is collapsed.
......@@ -164,7 +164,7 @@ NGRAPH_OP(PRelu, ngraph::op::v0, 0)
NGRAPH_OP(PSROIPooling, ngraph::op::v0, 0)
NGRAPH_OP(Pad, ngraph::op::v0, 0)
NGRAPH_OP(Pad, ngraph::op::v1, 1)
NGRAPH_OP(Parameter, ngraph::op, 0)
NGRAPH_OP(Parameter, ngraph::op::v0, 0)
NGRAPH_OP(PartialSlice, ngraph::op::v0, 0)
NGRAPH_OP(PartialSliceBackprop, ngraph::op::v0, 0)
NGRAPH_OP(Passthrough, ngraph::op, 0)
......@@ -202,11 +202,11 @@ NGRAPH_OP(ReorgYolo, ngraph::op::v0, 0)
NGRAPH_OP(ReplaceSlice, ngraph::op::v0, 0)
NGRAPH_OP(Reshape, ngraph::op::v0, 0)
NGRAPH_OP(Reshape, ngraph::op::v1, 1)
NGRAPH_OP(Result, ngraph::op, 0)
NGRAPH_OP(Result, ngraph::op::v0, 0)
NGRAPH_OP(Reverse, ngraph::op::v0, 0)
NGRAPH_OP(Reverse, ngraph::op::v1, 1)
NGRAPH_OP(ReverseSequence, ngraph::op::v0, 0)
NGRAPH_OP(ScalarConstantLike, ngraph::op, 0)
NGRAPH_OP(ScalarConstantLike, ngraph::op::v0, 0)
NGRAPH_OP(ScaleShift, ngraph::op::v0, 0)
NGRAPH_OP(ScatterAdd, ngraph::op::v0, 0)
NGRAPH_OP(ScatterND, ngraph::op::v0, 0)
......
......@@ -23,62 +23,66 @@ namespace ngraph
class Function;
namespace op
{
/// \brief A function parameter.
///
/// Parameters are nodes that represent the arguments that will be passed to user-defined
/// functions. Function creation requires a sequence of parameters. Basic graph operations
/// do not need parameters attached to a function.
class NGRAPH_API Parameter : public op::Op
namespace v0
{
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
public:
static constexpr NodeTypeInfo type_info{"Parameter", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructions a tensor-typed parameter node.
Parameter() = default;
/// \brief Constructions a tensor-typed parameter node.
/// \brief A function parameter.
///
/// \param element_type The element type of the parameter.
/// \param pshape The partial shape of the parameter.
/// \param cacheable True if the parameter is not expected to be frequently updated.
Parameter(const ngraph::element::Type& element_type,
const PartialShape& pshape,
const bool cacheable = false);
/// Parameters are nodes that represent the arguments that will be passed to
/// user-defined functions. Function creation requires a sequence of parameters.
/// Basic graph operations do not need parameters attached to a function.
class NGRAPH_API Parameter : public op::Op
{
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
bool visit_attributes(AttributeVisitor& visitor) override;
public:
static constexpr NodeTypeInfo type_info{"Parameter", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructions a tensor-typed parameter node.
Parameter() = default;
/// \brief Constructions a tensor-typed parameter node.
///
/// \param element_type The element type of the parameter.
/// \param pshape The partial shape of the parameter.
/// \param cacheable True if the parameter is not expected to be frequently updated.
Parameter(const ngraph::element::Type& element_type,
const PartialShape& pshape,
const bool cacheable = false);
bool is_parameter() const override { return true; }
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool get_cacheable() const { return m_cacheable; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool is_parameter() const override { return true; }
void validate_and_infer_types() override;
bool is_relevant_to_shapes() const;
void set_is_relevant_to_shapes(bool is_relevant);
bool get_cacheable() const { return m_cacheable; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const PartialShape& get_partial_shape() const { return m_partial_shape; }
PartialShape& get_partial_shape() { return m_partial_shape; }
void set_partial_shape(const PartialShape& partial_shape)
{
m_partial_shape = partial_shape;
}
bool is_relevant_to_shapes() const;
void set_is_relevant_to_shapes(bool is_relevant);
const element::Type& get_element_type() const { return m_element_type; }
void set_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
const PartialShape& get_partial_shape() const { return m_partial_shape; }
PartialShape& get_partial_shape() { return m_partial_shape; }
void set_partial_shape(const PartialShape& partial_shape)
{
m_partial_shape = partial_shape;
}
const element::Type& get_element_type() const { return m_element_type; }
void set_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
protected:
bool m_cacheable;
PartialShape m_partial_shape;
element::Type m_element_type;
bool m_is_relevant_to_shapes;
};
protected:
bool m_cacheable;
PartialShape m_partial_shape;
element::Type m_element_type;
bool m_is_relevant_to_shapes;
};
}
using v0::Parameter;
}
using ParameterVector = std::vector<std::shared_ptr<op::Parameter>>;
}
......@@ -24,33 +24,38 @@ namespace ngraph
{
namespace op
{
class NGRAPH_API Result : public Op
namespace v0
{
public:
static constexpr NodeTypeInfo type_info{"Result", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Allows a value to be used as a function result.
Result() = default;
/// \brief Allows a value to be used as a function result.
///
/// \param arg Node that produces the input tensor.
Result(const Output<Node>& arg, bool needs_default_layout = false);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_output() const override { return true; }
void set_needs_default_layout(bool val) { m_needs_default_layout = val; }
bool needs_default_layout() const { return m_needs_default_layout; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
private:
bool m_needs_default_layout{false};
};
class NGRAPH_API Result : public Op
{
public:
static constexpr NodeTypeInfo type_info{"Result", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Allows a value to be used as a function result.
Result() = default;
/// \brief Allows a value to be used as a function result.
///
/// \param arg Node that produces the input tensor.
Result(const Output<Node>& arg, bool needs_default_layout = false);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_output() const override { return true; }
void set_needs_default_layout(bool val) { m_needs_default_layout = val; }
bool needs_default_layout() const { return m_needs_default_layout; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
private:
bool m_needs_default_layout{false};
};
}
using v0::Result;
}
using ResultVector = std::vector<std::shared_ptr<op::Result>>;
}
......@@ -19,6 +19,7 @@
#include <memory>
#include "ngraph/axis_set.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
......
......@@ -115,7 +115,7 @@ NGRAPH_OP(OneHot, ngraph::op::v1)
NGRAPH_OP(PRelu, ngraph::op::v0)
NGRAPH_OP(PSROIPooling, ngraph::op::v0)
NGRAPH_OP(Pad, ngraph::op::v1)
NGRAPH_OP(Parameter, ngraph::op)
NGRAPH_OP(Parameter, ngraph::op::v0)
NGRAPH_OP(Power, ngraph::op::v1)
NGRAPH_OP(PriorBox, ngraph::op::v0)
NGRAPH_OP(PriorBoxClustered, ngraph::op::v0)
......@@ -131,7 +131,7 @@ NGRAPH_OP(ReduceProd, ngraph::op::v1)
NGRAPH_OP(ReduceSum, ngraph::op::v1)
NGRAPH_OP(RegionYolo, ngraph::op::v0)
NGRAPH_OP(Reshape, ngraph::op::v1)
NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Result, ngraph::op::v0)
NGRAPH_OP(Reverse, ngraph::op::v1)
NGRAPH_OP(ReverseSequence, ngraph::op::v0)
NGRAPH_OP(RNNCell, ngraph::op::v0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment