Unverified Commit 43e393e6 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Add some attribute visitors (#4279)

* Add some attribute visitors

* Need to keep size_t for opset1

* Fix condition

* Missed

* Missed int64_t

* Revert std changes

* Header reorg
Co-authored-by: 's avataraslepko <44713115+aslepko@users.noreply.github.com>
parent 61ce7176
......@@ -201,6 +201,26 @@ namespace ngraph
m_buffer_valid = false;
}
#ifdef __APPLE__
// size_t is not uint64_t on macOS
constexpr DiscreteTypeInfo AttributeAdapter<size_t>::type_info;
// Returns the adapted value as int64_t, lazily refreshing the shadow copy.
// The size_t member is mirrored into an int64_t buffer so a reference to a
// stable int64_t can be handed out; the copy is made only when stale.
const int64_t& AttributeAdapter<size_t>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = m_value;
    m_buffer_valid = true;
    return m_buffer;
}
// Stores a new value and marks the int64_t shadow buffer stale so the next
// get() re-copies it. The two assignments touch independent members.
void AttributeAdapter<size_t>::set(const int64_t& value)
{
    m_buffer_valid = false;
    m_value = value;
}
#endif
// Out-of-class definition required for the constexpr static data member (pre-C++17 ODR rules).
constexpr DiscreteTypeInfo AttributeAdapter<vector<int64_t>>::type_info;
// vector<int64_t> needs no conversion buffer; expose the referenced value directly.
const vector<int64_t>& AttributeAdapter<vector<int64_t>>::get() { return m_value; }
......
......@@ -246,6 +246,25 @@ namespace ngraph
void set(const int64_t& value) override;
};
#ifdef __APPLE__
// On macOS size_t is a distinct type from uint64_t, so it needs its own adapter
// Adapter exposing a size_t attribute through the int64_t ValueAccessor
// interface, converting via a shadow buffer in get()/set() (defined in the .cpp).
template <>
class NGRAPH_API AttributeAdapter<size_t> : public ValueReference<size_t>,
                                            public ValueAccessor<int64_t>
{
public:
    // Wraps (does not copy) the referenced size_t; the caller keeps ownership
    // and `value` must outlive the adapter.
    AttributeAdapter(size_t& value)
        : ValueReference<size_t>(value)
    {
    }
    static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<size_t>", 0};
    const DiscreteTypeInfo& get_type_info() const override { return type_info; }
    // Lazily-converted int64_t view of the wrapped size_t.
    const int64_t& get() override;
    // Writes back through the wrapper and invalidates the cached view.
    void set(const int64_t& value) override;
};
#endif
/// Note: These class bodies cannot be defined with templates because of interactions
/// between dllexport and templates on Windows.
template <>
......
......@@ -29,6 +29,7 @@
#include <unordered_set>
#include <vector>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/check.hpp"
#include "ngraph/coordinate.hpp"
......
......@@ -35,6 +35,12 @@ void op::Convert::validate_and_infer_types()
set_output_type(0, m_destination_type, get_input_partial_shape(0));
}
// Exposes Convert's serializable attributes to a visitor (serialization,
// cloning, comparison). Only the destination element type is an attribute.
bool op::Convert::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("destination_type", m_destination_type);
    return true;
}
shared_ptr<Node> op::Convert::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -39,7 +39,7 @@ namespace ngraph
Convert(const Output<Node>& arg, const ngraph::element::Type& destination_type);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const element::Type& get_destination_type() const { return m_destination_type; }
......
......@@ -34,6 +34,11 @@ void op::v1::ConvertLike::validate_and_infer_types()
set_output_type(0, get_input_element_type(1), get_input_partial_shape(0));
}
// ConvertLike takes its target type from its second input at runtime, so it
// has no attributes to visit; returning true still marks the op as visitable.
bool op::v1::ConvertLike::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}
shared_ptr<Node> op::v1::ConvertLike::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -39,6 +39,7 @@ namespace ngraph
ConvertLike(const Output<Node>& data, const Output<Node>& like);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -45,6 +45,16 @@ op::v1::Convolution::Convolution(const Output<Node>& data_batch,
constructor_validate_and_infer_types();
}
// Exposes v1::Convolution's serializable attributes to a visitor.
// Attribute names follow the opset1 convention (member name minus "m_").
bool op::v1::Convolution::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("strides", m_strides);
    visitor.on_attribute("dilations", m_dilations);
    visitor.on_attribute("pads_begin", m_pads_begin);
    visitor.on_attribute("pads_end", m_pads_end);
    visitor.on_attribute("auto_pad", m_auto_pad);
    return true;
}
void op::v1::Convolution::validate_and_infer_types()
{
const PartialShape& data_batch_shape = get_input_partial_shape(0);
......@@ -185,6 +195,17 @@ op::v1::ConvolutionBackpropData::ConvolutionBackpropData(const Output<Node>& dat
constructor_validate_and_infer_types();
}
// Exposes v1::ConvolutionBackpropData's serializable attributes to a visitor.
bool op::v1::ConvolutionBackpropData::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("strides", m_strides);
    visitor.on_attribute("dilations", m_dilations);
    visitor.on_attribute("pads_begin", m_pads_begin);
    visitor.on_attribute("pads_end", m_pads_end);
    visitor.on_attribute("auto_pad", m_auto_pad);
    visitor.on_attribute("output_padding", m_output_padding);
    return true;
}
op::v1::ConvolutionBackpropData::ConvolutionBackpropData(const Output<Node>& data,
const Output<Node>& filters,
const Strides& strides,
......@@ -498,6 +519,15 @@ op::v1::ConvolutionBackpropFilters::ConvolutionBackpropFilters(const Output<Node
constructor_validate_and_infer_types();
}
// Exposes v1::ConvolutionBackpropFilters' serializable attributes to a visitor.
bool op::v1::ConvolutionBackpropFilters::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("strides", m_strides);
    visitor.on_attribute("dilations", m_dilations);
    visitor.on_attribute("pads_begin", m_pads_begin);
    visitor.on_attribute("pads_end", m_pads_end);
    return true;
}
const Shape op::v1::ConvolutionBackpropFilters::get_filters_shape() const
{
Shape shape;
......@@ -647,6 +677,17 @@ op::v0::Convolution::Convolution(const Output<Node>& data_batch,
constructor_validate_and_infer_types();
}
// Exposes v0::Convolution's serializable attributes to a visitor.
// v0 uses the legacy attribute names (window_*, padding_below/above).
bool op::v0::Convolution::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("window_movement_strides", m_window_movement_strides);
    visitor.on_attribute("window_dilation_strides", m_window_dilation_strides);
    visitor.on_attribute("data_dilation_strides", m_data_dilation_strides);
    visitor.on_attribute("padding_below", m_padding_below);
    visitor.on_attribute("padding_above", m_padding_above);
    visitor.on_attribute("pad_type", m_pad_type);
    return true;
}
void op::v0::Convolution::validate_and_infer_types()
{
const PartialShape& data_batch_shape = get_input_partial_shape(0);
......@@ -839,6 +880,17 @@ op::v0::ConvolutionBackpropData::ConvolutionBackpropData(
constructor_validate_and_infer_types();
}
// Exposes v0::ConvolutionBackpropData's serializable attributes to a visitor.
// The *_forward attributes describe the forward convolution being inverted.
bool op::v0::ConvolutionBackpropData::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("data_batch_shape", m_data_batch_shape);
    visitor.on_attribute("window_movement_strides_forward", m_window_movement_strides_forward);
    visitor.on_attribute("window_dilation_strides_forward", m_window_dilation_strides_forward);
    visitor.on_attribute("padding_below_forward", m_padding_below_forward);
    visitor.on_attribute("padding_above_forward", m_padding_above_forward);
    visitor.on_attribute("data_dilation_strides_forward", m_data_dilation_strides_forward);
    return true;
}
void op::v0::ConvolutionBackpropData::validate_and_infer_types()
{
// Backprop to data is itself convolution, with inputs/outputs/attributes transmogrified as
......@@ -1070,6 +1122,17 @@ op::v0::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
constructor_validate_and_infer_types();
}
// Exposes v0::ConvolutionBackpropFilters' serializable attributes to a visitor.
// Fix: the attribute name is "filters_shape" — attribute names drop the "m_"
// member prefix everywhere else in this file ("strides", "padding_below_forward",
// ...); "m_filters_shape" leaked the member name into the serialized format.
bool op::v0::ConvolutionBackpropFilters::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("filters_shape", m_filters_shape);
    visitor.on_attribute("window_movement_strides_forward", m_window_movement_strides_forward);
    visitor.on_attribute("window_dilation_strides_forward", m_window_dilation_strides_forward);
    visitor.on_attribute("padding_below_forward", m_padding_below_forward);
    visitor.on_attribute("padding_above_forward", m_padding_above_forward);
    visitor.on_attribute("data_dilation_strides_forward", m_data_dilation_strides_forward);
    return true;
}
void op::v0::ConvolutionBackpropFilters::validate_and_infer_types()
{
// Backprop to filters is itself convolution, with inputs/outputs/attributes transmogrified as
......
......@@ -63,6 +63,7 @@ namespace ngraph
const PadType& auto_pad = PadType::EXPLICIT);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -156,8 +157,9 @@ namespace ngraph
const PadType& auto_pad = PadType::EXPLICIT,
const CoordinateDiff& output_padding = {});
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_dynamic() const override;
virtual void validate_and_infer_types() override;
void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
......@@ -224,6 +226,7 @@ namespace ngraph
const CoordinateDiff& pads_end);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -374,6 +377,7 @@ namespace ngraph
Convolution(const Output<Node>& data_batch, const Output<Node>& filters);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -468,6 +472,7 @@ namespace ngraph
const Strides& data_dilation_strides_forward);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
......@@ -570,6 +575,7 @@ namespace ngraph
const Strides& data_dilation_strides_forward);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -30,6 +30,11 @@ op::Cos::Cos(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
// Cos has no attributes; returning true still marks the op as visitable.
bool op::Cos::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}
shared_ptr<Node> op::Cos::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -36,6 +36,7 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Cos(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -29,6 +29,11 @@ op::Cosh::Cosh(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
// Cosh has no attributes; returning true still marks the op as visitable.
bool op::Cosh::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}
shared_ptr<Node> op::Cosh::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -36,6 +36,7 @@ namespace ngraph
///
/// \param arg Node that produces the input tensor.
Cosh(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -49,6 +49,18 @@ op::v1::DeformableConvolution::DeformableConvolution(const Output<Node>& arg,
constructor_validate_and_infer_types();
}
// Exposes v1::DeformableConvolution's serializable attributes to a visitor.
bool op::v1::DeformableConvolution::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("strides", m_strides);
    visitor.on_attribute("dilations", m_dilations);
    visitor.on_attribute("pads_begin", m_pads_begin);
    visitor.on_attribute("pads_end", m_pads_end);
    visitor.on_attribute("auto_pad", m_auto_pad);
    visitor.on_attribute("group", m_group);
    visitor.on_attribute("deformable_group", m_deformable_group);
    return true;
}
void op::v1::DeformableConvolution::validate_and_infer_types()
{
const PartialShape& data_batch_shape = get_input_partial_shape(0);
......
......@@ -65,6 +65,7 @@ namespace ngraph
const PadType& auto_pad = PadType::EXPLICIT,
const size_t group = 1,
const size_t deformable_group = 1);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
......
......@@ -68,6 +68,19 @@ op::v1::DeformablePSROIPooling::DeformablePSROIPooling(const Output<Node>& input
constructor_validate_and_infer_types();
}
// Exposes v1::DeformablePSROIPooling's serializable attributes to a visitor.
bool op::v1::DeformablePSROIPooling::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("output_dim", m_output_dim);
    visitor.on_attribute("spatial_scale", m_spatial_scale);
    visitor.on_attribute("group_size", m_group_size);
    visitor.on_attribute("mode", m_mode);
    visitor.on_attribute("spatial_bins_x", m_spatial_bins_x);
    visitor.on_attribute("spatial_bins_y", m_spatial_bins_y);
    visitor.on_attribute("trans_std", m_trans_std);
    visitor.on_attribute("part_size", m_part_size);
    return true;
}
void op::v1::DeformablePSROIPooling::validate_and_infer_types()
{
const auto& input_et = get_input_element_type(0);
......
......@@ -76,6 +76,8 @@ namespace ngraph
float trans_std = 1,
int64_t part_size = 1);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -43,6 +43,13 @@ op::v0::Divide::Divide(const Output<Node>& arg0,
constructor_validate_and_infer_types();
}
// Exposes v0::Divide's serializable attributes: the broadcast spec from the
// arithmetic base class plus the Python-division flag.
// Fix: the attribute name is "pythondiv" — names elsewhere in this file drop
// the "m_" member prefix; "m_pythondiv" leaked the member name into the
// serialized format.
bool op::v0::Divide::visit_attributes(AttributeVisitor& visitor)
{
    BinaryElementwiseArithmetic::visit_attributes(visitor);
    visitor.on_attribute("pythondiv", m_pythondiv);
    return true;
}
shared_ptr<Node> op::v0::Divide::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......@@ -93,6 +100,13 @@ op::v1::Divide::Divide(const Output<Node>& arg0,
constructor_validate_and_infer_types();
}
// Exposes v1::Divide's serializable attributes: the broadcast spec from the
// arithmetic base class plus the Python-division flag.
// Fix: the attribute name is "pythondiv" — names elsewhere in this file drop
// the "m_" member prefix; "m_pythondiv" leaked the member name into the
// serialized format.
bool op::v1::Divide::visit_attributes(AttributeVisitor& visitor)
{
    BinaryElementwiseArithmetic::visit_attributes(visitor);
    visitor.on_attribute("pythondiv", m_pythondiv);
    return true;
}
shared_ptr<Node> op::v1::Divide::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -54,7 +54,7 @@ namespace ngraph
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
bool visit_attributes(AttributeVisitor& visitor) override;
bool is_pythondiv() const { return m_pythondiv; }
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
virtual std::shared_ptr<Node>
......@@ -103,7 +103,7 @@ namespace ngraph
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override;
bool is_pythondiv() const { return m_pythondiv; }
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
virtual std::shared_ptr<Node>
......
......@@ -50,6 +50,12 @@ void op::CTCGreedyDecoder::validate_and_infer_types()
}
}
// Exposes CTCGreedyDecoder's single serializable attribute to a visitor.
bool op::CTCGreedyDecoder::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("ctc_merge_repeated", m_ctc_merge_repeated);
    return true;
}
shared_ptr<Node> op::CTCGreedyDecoder::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -40,7 +40,7 @@ namespace ngraph
const bool ctc_merge_repeated);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -60,3 +60,10 @@ shared_ptr<Node> op::Clamp::copy_with_new_args(const NodeVector& new_args) const
return make_shared<Clamp>(new_args.at(0), m_min, m_max);
}
// Exposes Clamp's serializable attributes (the clipping interval) to a visitor.
bool op::Clamp::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("min", m_min);
    visitor.on_attribute("max", m_max);
    return true;
}
......@@ -51,6 +51,8 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
double get_min() const { return m_min; }
double get_max() const { return m_max; }
private:
......
......@@ -44,6 +44,13 @@ op::DepthToSpace::DepthToSpace(const Output<Node>& data,
{
}
// Exposes DepthToSpace's serializable attributes to a visitor.
// Fix: the attribute name is "mode" — attribute names drop the "m_" member
// prefix ("blocksize" in the same body already does); "m_mode" leaked the
// member name into the serialized format.
bool op::DepthToSpace::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("blocksize", m_blocksize);
    visitor.on_attribute("mode", m_mode);
    return true;
}
NodeVector op::DepthToSpace::decompose_op() const
{
auto data = input_value(0);
......@@ -153,14 +160,28 @@ shared_ptr<Node> op::DepthToSpace::copy_with_new_args(const NodeVector& new_args
return make_shared<DepthToSpace>(new_args.at(0), m_mode, m_blocksize);
}
op::DepthToSpace::DepthToSpaceMode op::DepthToSpace::mode_from_string(const std::string& mode) const
namespace ngraph
{
static const std::map<std::string, DepthToSpaceMode> allowed_values = {
{"blocks_first", DepthToSpaceMode::BLOCKS_FIRST},
{"depth_first", DepthToSpaceMode::DEPTH_FIRST}};
// Registers the string <-> enum mapping used by as_string/as_enum for
// DepthToSpaceMode. The table is a function-local static so it is built once,
// on first use.
template <>
EnumNames<op::DepthToSpace::DepthToSpaceMode>&
    EnumNames<op::DepthToSpace::DepthToSpaceMode>::get()
{
    static auto enum_names = EnumNames<op::DepthToSpace::DepthToSpaceMode>(
        "op::DepthToSpace::DepthToSpaceMode",
        {{"blocks_first", op::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST},
         {"depth_first", op::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST}});
    return enum_names;
}
NODE_VALIDATION_CHECK(
this, allowed_values.count(mode) > 0, "Invalid 'depth_to_space_mode' value passed in.");
constexpr DiscreteTypeInfo AttributeAdapter<op::DepthToSpace::DepthToSpaceMode>::type_info;
return allowed_values.at(mode);
// Streams the mode's registered string name (via the EnumNames table above it
// in this translation unit).
std::ostream& operator<<(std::ostream& s, const op::DepthToSpace::DepthToSpaceMode& type)
{
    return s << as_string(type);
}
}
// Parses a mode string ("blocks_first"/"depth_first") into the enum using the
// shared EnumNames registry, replacing the previous hand-rolled map lookup.
op::DepthToSpace::DepthToSpaceMode op::DepthToSpace::mode_from_string(const std::string& mode) const
{
    return as_enum<DepthToSpaceMode>(mode);
}
......@@ -18,6 +18,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
......@@ -61,6 +62,7 @@ namespace ngraph
DepthToSpace(const Output<Node>& data,
const std::string& mode,
std::size_t block_size = 1);
bool visit_attributes(AttributeVisitor& visitor) override;
std::size_t get_block_size() const { return m_blocksize; }
DepthToSpaceMode get_mode() const { return m_mode; }
......@@ -77,4 +79,20 @@ namespace ngraph
}
using v0::DepthToSpace;
}
std::ostream& operator<<(std::ostream& s, const op::v0::DepthToSpace::DepthToSpaceMode& type);
// Adapter letting AttributeVisitor handle DepthToSpaceMode via the generic
// enum machinery (string conversion comes from the EnumNames specialization
// defined in the .cpp).
template <>
class NGRAPH_API AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>
    : public EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode>
{
public:
    // Wraps (does not copy) the referenced enum; `value` must outlive the adapter.
    AttributeAdapter(op::v0::DepthToSpace::DepthToSpaceMode& value)
        : EnumAttributeAdapterBase<op::v0::DepthToSpace::DepthToSpaceMode>(value)
    {
    }
    static constexpr DiscreteTypeInfo type_info{
        "AttributeAdapter<op::v0::DepthToSpace::DepthToSpaceMode>", 0};
    const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment