Unverified Commit 09b7e413 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Move some ops into v0 (#4138)

* Move some ops into v0

* namespace

* Make comments pretty

* Make comments pretty

* Merge fix
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent fbd99f34
......@@ -41,7 +41,10 @@ namespace ngraph
namespace op
{
class Parameter;
namespace v0
{
class Parameter;
}
}
void traverse_nodes(const std::shared_ptr<const Function> p,
......@@ -240,8 +243,8 @@ namespace ngraph
/// `body_replacement_map`, behavior is unspecified.
void replace_nodes(
const std::shared_ptr<Function>& f,
const std::unordered_map<std::shared_ptr<op::Parameter>, std::shared_ptr<op::Parameter>>&
parameter_replacement_map,
const std::unordered_map<std::shared_ptr<op::v0::Parameter>,
std::shared_ptr<op::v0::Parameter>>& parameter_replacement_map,
const std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>>&
body_replacement_map);
......@@ -400,7 +403,7 @@ namespace ngraph
// Assert that nodes in the function is colocated and return that placement
Placement get_colocated_function_placement(std::shared_ptr<Function> func);
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::Parameter>>
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::v0::Parameter>>
insert_result_parameter_split(const std::shared_ptr<Node>& src_node,
const std::shared_ptr<Node>& dst_node);
......
......@@ -60,11 +60,14 @@ namespace ngraph
namespace op
{
struct AutoBroadcastSpec;
class Constant;
class Result;
namespace v0
{
class Result;
}
} // namespace op
using ResultVector = std::vector<std::shared_ptr<op::Result>>;
using ResultVector = std::vector<std::shared_ptr<op::v0::Result>>;
namespace autodiff
{
......
......@@ -422,13 +422,16 @@ namespace ngraph
{
namespace op
{
template <>
void Constant::write_to_buffer<string>(const element::Type& /* target_type */,
const Shape& /* target_shape */,
const vector<string>& /* source */,
void* /* target */,
size_t /* target_element_count */)
namespace v0
{
template <>
void Constant::write_to_buffer<string>(const element::Type& /* target_type */,
const Shape& /* target_shape */,
const vector<string>& /* source */,
void* /* target */,
size_t /* target_element_count */)
{
}
}
}
}
......@@ -30,381 +30,388 @@ namespace ngraph
{
namespace op
{
/// \brief Class for constants.
class NGRAPH_API Constant : public Op
namespace v0
{
public:
static constexpr NodeTypeInfo type_info{"Constant", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Constant() = default;
/// \brief Constructs a tensor constant.
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param values A vector of literals for initializing the tensor constant. The size
/// of values must match the size of the shape.
template <typename T>
Constant(const element::Type& type, Shape shape, const std::vector<T>& values)
: m_element_type(type)
, m_shape(shape)
, m_data(new runtime::AlignedBuffer(
std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f),
host_alignment()))
/// \brief Class for constants.
class NGRAPH_API Constant : public Op
{
NODE_VALIDATION_CHECK(
this,
values.size() == 1 || values.size() == shape_size(m_shape),
"Did not get the expected number of literals for a constant of shape ",
m_shape,
" (got ",
values.size(),
", expected ",
(shape_size(m_shape) == 1 ? "" : "1 or "),
shape_size(m_shape),
").");
if (values.size() == 1)
{
write_values(std::vector<T>(shape_size(m_shape), values[0]));
}
else
public:
static constexpr NodeTypeInfo type_info{"Constant", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Constant() = default;
/// \brief Constructs a tensor constant.
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param values A vector of literals for initializing the tensor constant. The
/// size of values must match the size of the shape.
template <typename T>
Constant(const element::Type& type, Shape shape, const std::vector<T>& values)
: m_element_type(type)
, m_shape(shape)
, m_data(new runtime::AlignedBuffer(
std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f),
host_alignment()))
{
write_values(values);
}
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
/// \brief Constructs a tensor constant
/// This constructor is mainly to support deserialization of constants.
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param values A list of string values to use as the constant data.
Constant(const element::Type& type,
Shape shape,
const std::vector<std::string>& values);
/// \brief Constructs a tensor constant with the same initialization value copied across
// the tensor. This constructor is to support deserialization of constants.
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param data A void* to constant data.
Constant(const element::Type& type, const Shape& shape, const void* data);
virtual ~Constant() override;
NODE_VALIDATION_CHECK(
this,
values.size() == 1 || values.size() == shape_size(m_shape),
"Did not get the expected number of literals for a constant of shape ",
m_shape,
" (got ",
values.size(),
", expected ",
(shape_size(m_shape) == 1 ? "" : "1 or "),
shape_size(m_shape),
").");
void validate_and_infer_types() override
{
infer_element_type();
set_output_type(0, m_element_type, m_shape);
}
/// \brief Returns the value of the constant node as a Shape object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Shape get_shape_val() const;
/// \brief Returns the value of the constant node as a Strides
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Strides get_strides_val() const;
/// \brief Returns the value of the constant node as a Coordinate
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Coordinate get_coordinate_val() const;
/// \brief Returns the value of the constant node as a
/// CoordinateDiff object
/// Can only be used on element::i64 nodes.
CoordinateDiff get_coordinate_diff_val() const;
/// \brief Returns the value of the constant node as an AxisVector
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
AxisVector get_axis_vector_val() const;
/// \brief Returns the value of the constant node as an AxisSet
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
/// Repeated values are allowed.
AxisSet get_axis_set_val() const;
/// \brief Wrapper around constructing a shared_ptr of a Constant
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param values A vector of values to use as the constant data.
template <typename T>
static std::shared_ptr<op::Constant>
create(const element::Type& type, Shape shape, const std::vector<T> values)
{
auto result = std::make_shared<op::Constant>(type, shape, values);
result->validate_and_infer_types();
return result;
}
if (values.size() == 1)
{
write_values(std::vector<T>(shape_size(m_shape), values[0]));
}
else
{
write_values(values);
}
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
/// \brief Wrapper around constructing a shared_ptr of a Constant
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param values An initializer_list of values to use as the constant data.
template <typename T>
static std::shared_ptr<op::Constant>
create(const element::Type& type, Shape shape, std::initializer_list<T> values)
{
auto result = std::make_shared<op::Constant>(type, shape, std::vector<T>{values});
result->validate_and_infer_types();
return result;
}
/// \brief Constructs a tensor constant
/// This constructor is mainly to support deserialization of constants.
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param values A list of string values to use as the constant data.
Constant(const element::Type& type,
Shape shape,
const std::vector<std::string>& values);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \brief Constructs a tensor constant with the same initialization value copied
/// across the tensor. This constructor is to support deserialization of
/// constants.
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param data A void* to constant data.
Constant(const element::Type& type, const Shape& shape, const void* data);
/// \return The initialization literals for the tensor constant.
std::vector<std::string> get_value_strings() const;
virtual ~Constant() override;
template <typename T>
std::vector<T> get_vector() const
{
if (sizeof(T) > m_element_type.size() && shape_size(m_shape) > 0)
void validate_and_infer_types() override
{
throw ngraph_error("Buffer over-read");
infer_element_type();
set_output_type(0, m_element_type, m_shape);
}
std::vector<T> rc;
const T* p = reinterpret_cast<const T*>(m_data->get_ptr());
for (size_t i = 0; i < shape_size(m_shape); i++)
{
rc.push_back(p[i]);
}
return rc;
}
/// \brief Returns the value of the constant node as a Shape object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Shape get_shape_val() const;
/// \brief Returns the value of the constant node as a Strides
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Strides get_strides_val() const;
/// \brief Returns the value of the constant node as a Coordinate
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Coordinate get_coordinate_val() const;
/// \brief Returns the value of the constant node as a
/// CoordinateDiff object
/// Can only be used on element::i64 nodes.
CoordinateDiff get_coordinate_diff_val() const;
/// \brief Returns the value of the constant node as an AxisVector
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
AxisVector get_axis_vector_val() const;
/// \brief Returns the value of the constant node as an AxisSet
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
/// Repeated values are allowed.
AxisSet get_axis_set_val() const;
/// \brief Return the Constant's value as a vector cast to type T
///
/// \tparam T Type to which data vector's entries will be cast.
/// \return Constant's data vector.
template <typename T>
std::vector<T> cast_vector() const
{
auto source_type = get_element_type();
switch (source_type)
{
case element::Type_t::boolean:
{
auto vector = get_vector<char>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::bf16:
{
auto vector = get_vector<bfloat16>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::f16:
{
auto vector = get_vector<float16>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::f32:
{
auto vector = get_vector<float>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::f64:
{
auto vector = get_vector<double>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::i8:
/// \brief Wrapper around constructing a shared_ptr of a Constant
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param values A vector of values to use as the constant data.
template <typename T>
static std::shared_ptr<op::v0::Constant>
create(const element::Type& type, Shape shape, const std::vector<T> values)
{
auto vector = get_vector<int8_t>();
return std::vector<T>(vector.begin(), vector.end());
auto result = std::make_shared<op::v0::Constant>(type, shape, values);
result->validate_and_infer_types();
return result;
}
case element::Type_t::i16:
/// \brief Wrapper around constructing a shared_ptr of a Constant
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param values An initializer_list of values to use as the constant data.
template <typename T>
static std::shared_ptr<op::v0::Constant>
create(const element::Type& type, Shape shape, std::initializer_list<T> values)
{
auto vector = get_vector<int16_t>();
return std::vector<T>(vector.begin(), vector.end());
auto result =
std::make_shared<op::v0::Constant>(type, shape, std::vector<T>{values});
result->validate_and_infer_types();
return result;
}
case element::Type_t::i32:
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The initialization literals for the tensor constant.
std::vector<std::string> get_value_strings() const;
template <typename T>
std::vector<T> get_vector() const
{
auto vector = get_vector<int32_t>();
return std::vector<T>(vector.begin(), vector.end());
if (sizeof(T) > m_element_type.size() && shape_size(m_shape) > 0)
{
throw ngraph_error("Buffer over-read");
}
std::vector<T> rc;
const T* p = reinterpret_cast<const T*>(m_data->get_ptr());
for (size_t i = 0; i < shape_size(m_shape); i++)
{
rc.push_back(p[i]);
}
return rc;
}
case element::Type_t::i64:
/// \brief Return the Constant's value as a vector cast to type T
///
/// \tparam T Type to which data vector's entries will be cast.
/// \return Constant's data vector.
template <typename T>
std::vector<T> cast_vector() const
{
auto vector = get_vector<int64_t>();
return std::vector<T>(vector.begin(), vector.end());
auto source_type = get_element_type();
switch (source_type)
{
case element::Type_t::boolean:
{
auto vector = get_vector<char>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::bf16:
{
auto vector = get_vector<bfloat16>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::f16:
{
auto vector = get_vector<float16>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::f32:
{
auto vector = get_vector<float>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::f64:
{
auto vector = get_vector<double>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::i8:
{
auto vector = get_vector<int8_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::i16:
{
auto vector = get_vector<int16_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::i32:
{
auto vector = get_vector<int32_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::i64:
{
auto vector = get_vector<int64_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u8:
{
auto vector = get_vector<uint8_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u16:
{
auto vector = get_vector<uint16_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u32:
{
auto vector = get_vector<uint32_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u64:
{
auto vector = get_vector<uint64_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u1:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
}
case element::Type_t::u8:
const void* get_data_ptr() const { return (m_data ? m_data->get_ptr() : nullptr); }
template <typename T>
const T* get_data_ptr() const
{
auto vector = get_vector<uint8_t>();
return std::vector<T>(vector.begin(), vector.end());
return reinterpret_cast<const T*>(get_data_ptr());
}
case element::Type_t::u16:
bool is_constant() const override { return true; }
bool get_all_data_elements_bitwise_identical() const
{
auto vector = get_vector<uint16_t>();
return std::vector<T>(vector.begin(), vector.end());
return m_all_elements_bitwise_identical;
}
case element::Type_t::u32:
std::string convert_value_to_string(size_t index) const;
protected:
void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); }
Constant(const OutputVector& args)
: Op(args)
, m_shape({})
{
auto vector = get_vector<uint32_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u64:
virtual void infer_element_type() {}
template <typename T>
void write_values(const std::vector<T>& values)
{
auto vector = get_vector<uint64_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u1:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
write_to_buffer(
m_element_type, m_shape, values, get_data_ptr_nc(), shape_size(m_shape));
}
}
const void* get_data_ptr() const { return (m_data ? m_data->get_ptr() : nullptr); }
template <typename T>
const T* get_data_ptr() const
{
return reinterpret_cast<const T*>(get_data_ptr());
}
bool is_constant() const override { return true; }
bool get_all_data_elements_bitwise_identical() const
{
return m_all_elements_bitwise_identical;
}
std::string convert_value_to_string(size_t index) const;
protected:
void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); }
Constant(const OutputVector& args)
: Op(args)
, m_shape({})
{
}
virtual void infer_element_type() {}
template <typename T>
void write_values(const std::vector<T>& values)
{
write_to_buffer(
m_element_type, m_shape, values, get_data_ptr_nc(), shape_size(m_shape));
}
template <typename T, typename U>
void write_buffer(void* target, const std::vector<U>& source, size_t count)
{
T* p = reinterpret_cast<T*>(target);
for (size_t i = 0; i < count; i++)
template <typename T, typename U>
void write_buffer(void* target, const std::vector<U>& source, size_t count)
{
p[i] = static_cast<T>(source[i]);
T* p = reinterpret_cast<T*>(target);
for (size_t i = 0; i < count; i++)
{
p[i] = static_cast<T>(source[i]);
}
}
}
template <typename T>
void write_to_buffer(const element::Type& target_type,
const Shape& /* target_shape */,
const std::vector<T>& source,
void* target,
size_t target_element_count)
{
if (source.size() != target_element_count)
template <typename T>
void write_to_buffer(const element::Type& target_type,
const Shape& /* target_shape */,
const std::vector<T>& source,
void* target,
size_t target_element_count)
{
throw std::runtime_error("Constant initializer does not match shape");
}
if (source.size() != target_element_count)
{
throw std::runtime_error("Constant initializer does not match shape");
}
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
switch (target_type)
{
case element::Type_t::boolean:
write_buffer<char, T>(target, source, target_element_count);
break;
case element::Type_t::bf16:
write_buffer<bfloat16, T>(target, source, target_element_count);
break;
case element::Type_t::f16:
write_buffer<float16, T>(target, source, target_element_count);
break;
case element::Type_t::f32:
write_buffer<float, T>(target, source, target_element_count);
break;
case element::Type_t::f64:
write_buffer<double, T>(target, source, target_element_count);
break;
case element::Type_t::i8:
write_buffer<int8_t, T>(target, source, target_element_count);
break;
case element::Type_t::i16:
write_buffer<int16_t, T>(target, source, target_element_count);
break;
case element::Type_t::i32:
write_buffer<int32_t, T>(target, source, target_element_count);
break;
case element::Type_t::i64:
write_buffer<int64_t, T>(target, source, target_element_count);
break;
case element::Type_t::u8:
write_buffer<uint8_t, T>(target, source, target_element_count);
break;
case element::Type_t::u16:
write_buffer<uint16_t, T>(target, source, target_element_count);
break;
case element::Type_t::u32:
write_buffer<uint32_t, T>(target, source, target_element_count);
break;
case element::Type_t::u64:
write_buffer<uint64_t, T>(target, source, target_element_count);
break;
case element::Type_t::u1: throw std::runtime_error("unsupported type");
case element::Type_t::undefined: throw std::runtime_error("unsupported type");
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
switch (target_type)
{
case element::Type_t::boolean:
write_buffer<char, T>(target, source, target_element_count);
break;
case element::Type_t::bf16:
write_buffer<bfloat16, T>(target, source, target_element_count);
break;
case element::Type_t::f16:
write_buffer<float16, T>(target, source, target_element_count);
break;
case element::Type_t::f32:
write_buffer<float, T>(target, source, target_element_count);
break;
case element::Type_t::f64:
write_buffer<double, T>(target, source, target_element_count);
break;
case element::Type_t::i8:
write_buffer<int8_t, T>(target, source, target_element_count);
break;
case element::Type_t::i16:
write_buffer<int16_t, T>(target, source, target_element_count);
break;
case element::Type_t::i32:
write_buffer<int32_t, T>(target, source, target_element_count);
break;
case element::Type_t::i64:
write_buffer<int64_t, T>(target, source, target_element_count);
break;
case element::Type_t::u8:
write_buffer<uint8_t, T>(target, source, target_element_count);
break;
case element::Type_t::u16:
write_buffer<uint16_t, T>(target, source, target_element_count);
break;
case element::Type_t::u32:
write_buffer<uint32_t, T>(target, source, target_element_count);
break;
case element::Type_t::u64:
write_buffer<uint64_t, T>(target, source, target_element_count);
break;
case element::Type_t::u1: throw std::runtime_error("unsupported type");
case element::Type_t::undefined: throw std::runtime_error("unsupported type");
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
#pragma GCC diagnostic pop
}
}
static constexpr size_t host_alignment() { return 64; }
element::Type m_element_type;
Shape m_shape{};
std::unique_ptr<runtime::AlignedBuffer> m_data;
bool m_all_elements_bitwise_identical;
bool are_all_data_elements_bitwise_identical() const;
Constant(const Constant&) = delete;
Constant operator=(const Constant&) = delete;
};
static constexpr size_t host_alignment() { return 64; }
element::Type m_element_type;
Shape m_shape{};
std::unique_ptr<runtime::AlignedBuffer> m_data;
bool m_all_elements_bitwise_identical;
bool are_all_data_elements_bitwise_identical() const;
Constant(const Constant&) = delete;
Constant operator=(const Constant&) = delete;
};
/// \brief A scalar constant whose element type is the same as like.
class NGRAPH_API ScalarConstantLike : public Constant
{
public:
static constexpr NodeTypeInfo type_info{"ScalarConstantLike", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief A scalar constant whose element type is the same as like.
///
/// Once the element type is known, the dependency on like will be removed and
/// this node will be replaced with an equivalent constant.
///
/// \param like A tensor that will supply the element type.
/// \param value The value of the scalar.
template <typename T>
ScalarConstantLike(const Output<Node>& like, T value)
: Constant({like})
, m_value(static_cast<double>(value))
class NGRAPH_API ScalarConstantLike : public Constant
{
constructor_validate_and_infer_types();
}
public:
static constexpr NodeTypeInfo type_info{"ScalarConstantLike", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief A scalar constant whose element type is the same as like.
///
/// Once the element type is known, the dependency on like will be removed and
/// this node will be replaced with an equivalent constant.
///
/// \param like A tensor that will supply the element type.
/// \param value The value of the scalar.
template <typename T>
ScalarConstantLike(const Output<Node>& like, T value)
: Constant({like})
, m_value(static_cast<double>(value))
{
constructor_validate_and_infer_types();
}
ScalarConstantLike() = default;
ScalarConstantLike() = default;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<op::Constant> as_constant() const;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<op::v0::Constant> as_constant() const;
protected:
void infer_element_type() override;
protected:
void infer_element_type() override;
double m_value;
};
double m_value;
};
}
using v0::Constant;
using v0::ScalarConstantLike;
}
}
......@@ -164,7 +164,7 @@ NGRAPH_OP(PRelu, ngraph::op::v0, 0)
NGRAPH_OP(PSROIPooling, ngraph::op::v0, 0)
NGRAPH_OP(Pad, ngraph::op::v0, 0)
NGRAPH_OP(Pad, ngraph::op::v1, 1)
NGRAPH_OP(Parameter, ngraph::op, 0)
NGRAPH_OP(Parameter, ngraph::op::v0, 0)
NGRAPH_OP(PartialSlice, ngraph::op::v0, 0)
NGRAPH_OP(PartialSliceBackprop, ngraph::op::v0, 0)
NGRAPH_OP(Passthrough, ngraph::op, 0)
......@@ -202,11 +202,11 @@ NGRAPH_OP(ReorgYolo, ngraph::op::v0, 0)
NGRAPH_OP(ReplaceSlice, ngraph::op::v0, 0)
NGRAPH_OP(Reshape, ngraph::op::v0, 0)
NGRAPH_OP(Reshape, ngraph::op::v1, 1)
NGRAPH_OP(Result, ngraph::op, 0)
NGRAPH_OP(Result, ngraph::op::v0, 0)
NGRAPH_OP(Reverse, ngraph::op::v0, 0)
NGRAPH_OP(Reverse, ngraph::op::v1, 1)
NGRAPH_OP(ReverseSequence, ngraph::op::v0, 0)
NGRAPH_OP(ScalarConstantLike, ngraph::op, 0)
NGRAPH_OP(ScalarConstantLike, ngraph::op::v0, 0)
NGRAPH_OP(ScaleShift, ngraph::op::v0, 0)
NGRAPH_OP(ScatterAdd, ngraph::op::v0, 0)
NGRAPH_OP(ScatterND, ngraph::op::v0, 0)
......
......@@ -23,62 +23,66 @@ namespace ngraph
class Function;
namespace op
{
/// \brief A function parameter.
///
/// Parameters are nodes that represent the arguments that will be passed to user-defined
/// functions. Function creation requires a sequence of parameters. Basic graph operations
/// do not need parameters attached to a function.
class NGRAPH_API Parameter : public op::Op
namespace v0
{
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
public:
static constexpr NodeTypeInfo type_info{"Parameter", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructions a tensor-typed parameter node.
Parameter() = default;
/// \brief Constructions a tensor-typed parameter node.
/// \brief A function parameter.
///
/// \param element_type The element type of the parameter.
/// \param pshape The partial shape of the parameter.
/// \param cacheable True if the parameter is not expected to be frequently updated.
Parameter(const ngraph::element::Type& element_type,
const PartialShape& pshape,
const bool cacheable = false);
/// Parameters are nodes that represent the arguments that will be passed to
/// user-defined functions. Function creation requires a sequence of parameters.
/// Basic graph operations do not need parameters attached to a function.
class NGRAPH_API Parameter : public op::Op
{
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
bool visit_attributes(AttributeVisitor& visitor) override;
public:
static constexpr NodeTypeInfo type_info{"Parameter", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructions a tensor-typed parameter node.
Parameter() = default;
/// \brief Constructions a tensor-typed parameter node.
///
/// \param element_type The element type of the parameter.
/// \param pshape The partial shape of the parameter.
/// \param cacheable True if the parameter is not expected to be frequently updated.
Parameter(const ngraph::element::Type& element_type,
const PartialShape& pshape,
const bool cacheable = false);
bool is_parameter() const override { return true; }
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool get_cacheable() const { return m_cacheable; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool is_parameter() const override { return true; }
void validate_and_infer_types() override;
bool is_relevant_to_shapes() const;
void set_is_relevant_to_shapes(bool is_relevant);
bool get_cacheable() const { return m_cacheable; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const PartialShape& get_partial_shape() const { return m_partial_shape; }
PartialShape& get_partial_shape() { return m_partial_shape; }
void set_partial_shape(const PartialShape& partial_shape)
{
m_partial_shape = partial_shape;
}
bool is_relevant_to_shapes() const;
void set_is_relevant_to_shapes(bool is_relevant);
const element::Type& get_element_type() const { return m_element_type; }
void set_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
const PartialShape& get_partial_shape() const { return m_partial_shape; }
PartialShape& get_partial_shape() { return m_partial_shape; }
void set_partial_shape(const PartialShape& partial_shape)
{
m_partial_shape = partial_shape;
}
const element::Type& get_element_type() const { return m_element_type; }
void set_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
protected:
bool m_cacheable;
PartialShape m_partial_shape;
element::Type m_element_type;
bool m_is_relevant_to_shapes;
};
protected:
bool m_cacheable;
PartialShape m_partial_shape;
element::Type m_element_type;
bool m_is_relevant_to_shapes;
};
}
using v0::Parameter;
}
using ParameterVector = std::vector<std::shared_ptr<op::Parameter>>;
}
......@@ -24,33 +24,38 @@ namespace ngraph
{
namespace op
{
class NGRAPH_API Result : public Op
namespace v0
{
public:
static constexpr NodeTypeInfo type_info{"Result", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Allows a value to be used as a function result.
Result() = default;
/// \brief Allows a value to be used as a function result.
///
/// \param arg Node that produces the input tensor.
Result(const Output<Node>& arg, bool needs_default_layout = false);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_output() const override { return true; }
void set_needs_default_layout(bool val) { m_needs_default_layout = val; }
bool needs_default_layout() const { return m_needs_default_layout; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
private:
bool m_needs_default_layout{false};
};
class NGRAPH_API Result : public Op
{
public:
static constexpr NodeTypeInfo type_info{"Result", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Allows a value to be used as a function result.
Result() = default;
/// \brief Allows a value to be used as a function result.
///
/// \param arg Node that produces the input tensor.
Result(const Output<Node>& arg, bool needs_default_layout = false);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_output() const override { return true; }
void set_needs_default_layout(bool val) { m_needs_default_layout = val; }
bool needs_default_layout() const { return m_needs_default_layout; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
private:
bool m_needs_default_layout{false};
};
}
using v0::Result;
}
using ResultVector = std::vector<std::shared_ptr<op::Result>>;
}
......@@ -19,6 +19,7 @@
#include <memory>
#include "ngraph/axis_set.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
......
......@@ -115,7 +115,7 @@ NGRAPH_OP(OneHot, ngraph::op::v1)
NGRAPH_OP(PRelu, ngraph::op::v0)
NGRAPH_OP(PSROIPooling, ngraph::op::v0)
NGRAPH_OP(Pad, ngraph::op::v1)
NGRAPH_OP(Parameter, ngraph::op)
NGRAPH_OP(Parameter, ngraph::op::v0)
NGRAPH_OP(Power, ngraph::op::v1)
NGRAPH_OP(PriorBox, ngraph::op::v0)
NGRAPH_OP(PriorBoxClustered, ngraph::op::v0)
......@@ -131,7 +131,7 @@ NGRAPH_OP(ReduceProd, ngraph::op::v1)
NGRAPH_OP(ReduceSum, ngraph::op::v1)
NGRAPH_OP(RegionYolo, ngraph::op::v0)
NGRAPH_OP(Reshape, ngraph::op::v1)
NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Result, ngraph::op::v0)
NGRAPH_OP(Reverse, ngraph::op::v1)
NGRAPH_OP(ReverseSequence, ngraph::op::v0)
NGRAPH_OP(RNNCell, ngraph::op::v0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment