Unverified Commit ca955d46 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

A couple new attribute visitors (#4233)

parent 46c21a0d
...@@ -27,19 +27,19 @@ using namespace ngraph; ...@@ -27,19 +27,19 @@ using namespace ngraph;
constexpr NodeTypeInfo op::Atan2::type_info; constexpr NodeTypeInfo op::Atan2::type_info;
op::Atan2::Atan2(const Output<Node>& y, const Output<Node>& x, const AutoBroadcastSpec& autob) op::v0::Atan2::Atan2(const Output<Node>& y, const Output<Node>& x, const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic(y, x, autob) : BinaryElementwiseArithmetic(y, x, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Atan2::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Atan2::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Atan2>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<Atan2>(new_args.at(0), new_args.at(1), this->get_autob());
} }
void op::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVector& deltas) void op::v0::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVector& deltas)
{ {
if (get_autob().m_type != op::AutoBroadcastType::NONE) if (get_autob().m_type != op::AutoBroadcastType::NONE)
{ {
...@@ -51,3 +51,9 @@ void op::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVect ...@@ -51,3 +51,9 @@ void op::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVect
adjoints.add_delta(y, x * delta_over_r); adjoints.add_delta(y, x * delta_over_r);
adjoints.add_delta(x, -y * delta_over_r); adjoints.add_delta(x, -y * delta_over_r);
} }
bool op::v0::Atan2::visit_attributes(AttributeVisitor& visitor)
{
BinaryElementwiseArithmetic::visit_attributes(visitor);
return true;
}
...@@ -24,30 +24,35 @@ namespace ngraph ...@@ -24,30 +24,35 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Elementwise full arctan operation namespace v0
class NGRAPH_API Atan2 : public util::BinaryElementwiseArithmetic
{ {
public: /// \brief Elementwise full arctan operation
static constexpr NodeTypeInfo type_info{"Atan2", 0}; class NGRAPH_API Atan2 : public util::BinaryElementwiseArithmetic
const NodeTypeInfo& get_type_info() const override { return type_info; }
Atan2()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{ {
} public:
static constexpr NodeTypeInfo type_info{"Atan2", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Atan2()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief atan2(y,x) is the angle from the origin to the point (x,y) (note reversed /// \brief atan2(y,x) is the angle from the origin to the point (x,y) (note reversed
/// order). /// order).
/// ///
/// \param y /// \param y
/// \param x /// \param x
Atan2(const Output<Node>& y, Atan2(const Output<Node>& y,
const Output<Node>& x, const Output<Node>& x,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override; const OutputVector& deltas) override;
}; };
}
using v0::Atan2;
} }
} }
...@@ -142,20 +142,48 @@ shared_ptr<Node> op::v1::BinaryConvolution::copy_with_new_args(const NodeVector& ...@@ -142,20 +142,48 @@ shared_ptr<Node> op::v1::BinaryConvolution::copy_with_new_args(const NodeVector&
m_auto_pad); m_auto_pad);
} }
bool op::v1::BinaryConvolution::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("strides", m_strides);
visitor.on_attribute("pads_begin", m_pads_begin);
visitor.on_attribute("pads_end", m_pads_end);
visitor.on_attribute("dilations", m_dilations);
visitor.on_attribute("mode", m_mode);
visitor.on_attribute("pad_value", m_pad_value);
visitor.on_attribute("auto_pad", m_auto_pad);
return true;
}
void op::v1::BinaryConvolution::generate_adjoints(autodiff::Adjoints& adjoints, void op::v1::BinaryConvolution::generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) const OutputVector& deltas)
{ {
throw ngraph_error("BinaryConvolution generate_adjoints not implemented"); throw ngraph_error("BinaryConvolution generate_adjoints not implemented");
} }
op::v1::BinaryConvolution::BinaryConvolutionMode namespace ngraph
op::v1::BinaryConvolution::mode_from_string(const std::string& mode) const
{ {
static const std::map<std::string, BinaryConvolutionMode> allowed_values = { template <>
{"xnor-popcount", BinaryConvolutionMode::XNOR_POPCOUNT}}; EnumNames<op::v1::BinaryConvolution::BinaryConvolutionMode>&
EnumNames<op::v1::BinaryConvolution::BinaryConvolutionMode>::get()
{
static auto enum_names = EnumNames<op::v1::BinaryConvolution::BinaryConvolutionMode>(
"op::v1::BinaryConvolution::BinaryConvolutionMode",
{{"xnor-popcount", op::v1::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT}});
return enum_names;
}
NODE_VALIDATION_CHECK( constexpr DiscreteTypeInfo
this, allowed_values.count(mode) > 0, "Invalid binary convolution mode value passed in."); AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>::type_info;
return allowed_values.at(mode); std::ostream& operator<<(std::ostream& s,
const op::v1::BinaryConvolution::BinaryConvolutionMode& type)
{
return s << as_string(type);
}
}
op::v1::BinaryConvolution::BinaryConvolutionMode
op::v1::BinaryConvolution::mode_from_string(const std::string& mode) const
{
return as_enum<BinaryConvolutionMode>(mode);
} }
...@@ -74,6 +74,8 @@ namespace ngraph ...@@ -74,6 +74,8 @@ namespace ngraph
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void generate_adjoints(autodiff::Adjoints& adjoints, void generate_adjoints(autodiff::Adjoints& adjoints,
...@@ -112,4 +114,23 @@ namespace ngraph ...@@ -112,4 +114,23 @@ namespace ngraph
}; };
} }
} // namespace op } // namespace op
std::ostream& operator<<(std::ostream& s,
const op::v1::BinaryConvolution::BinaryConvolutionMode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>
: public EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>
{
public:
AttributeAdapter(op::v1::BinaryConvolution::BinaryConvolutionMode& value)
: EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph } // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment