Unverified Commit ca955d46 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

A couple new attribute visitors (#4233)

parent 46c21a0d
......@@ -27,19 +27,19 @@ using namespace ngraph;
constexpr NodeTypeInfo op::Atan2::type_info;
op::Atan2::Atan2(const Output<Node>& y, const Output<Node>& x, const AutoBroadcastSpec& autob)
op::v0::Atan2::Atan2(const Output<Node>& y, const Output<Node>& x, const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic(y, x, autob)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Atan2::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::Atan2::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Atan2>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVector& deltas)
void op::v0::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
......@@ -51,3 +51,9 @@ void op::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVect
adjoints.add_delta(y, x * delta_over_r);
adjoints.add_delta(x, -y * delta_over_r);
}
bool op::v0::Atan2::visit_attributes(AttributeVisitor& visitor)
{
BinaryElementwiseArithmetic::visit_attributes(visitor);
return true;
}
......@@ -24,30 +24,35 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise full arctan operation
class NGRAPH_API Atan2 : public util::BinaryElementwiseArithmetic
namespace v0
{
public:
static constexpr NodeTypeInfo type_info{"Atan2", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Atan2()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
/// \brief Elementwise full arctan operation
class NGRAPH_API Atan2 : public util::BinaryElementwiseArithmetic
{
}
public:
static constexpr NodeTypeInfo type_info{"Atan2", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Atan2()
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NONE)
{
}
/// \brief atan2(y,x) is the angle from the origin to the point (x,y) (note reversed
/// order).
///
/// \param y
/// \param x
Atan2(const Output<Node>& y,
const Output<Node>& x,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
/// \brief atan2(y,x) is the angle from the origin to the point (x,y) (note reversed
/// order).
///
/// \param y
/// \param x
Atan2(const Output<Node>& y,
const Output<Node>& x,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
};
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
};
}
using v0::Atan2;
}
}
......@@ -142,20 +142,48 @@ shared_ptr<Node> op::v1::BinaryConvolution::copy_with_new_args(const NodeVector&
m_auto_pad);
}
bool op::v1::BinaryConvolution::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("strides", m_strides);
visitor.on_attribute("pads_begin", m_pads_begin);
visitor.on_attribute("pads_end", m_pads_end);
visitor.on_attribute("dilations", m_dilations);
visitor.on_attribute("mode", m_mode);
visitor.on_attribute("pad_value", m_pad_value);
visitor.on_attribute("auto_pad", m_auto_pad);
return true;
}
void op::v1::BinaryConvolution::generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas)
{
throw ngraph_error("BinaryConvolution generate_adjoints not implemented");
}
op::v1::BinaryConvolution::BinaryConvolutionMode
op::v1::BinaryConvolution::mode_from_string(const std::string& mode) const
namespace ngraph
{
static const std::map<std::string, BinaryConvolutionMode> allowed_values = {
{"xnor-popcount", BinaryConvolutionMode::XNOR_POPCOUNT}};
template <>
EnumNames<op::v1::BinaryConvolution::BinaryConvolutionMode>&
EnumNames<op::v1::BinaryConvolution::BinaryConvolutionMode>::get()
{
static auto enum_names = EnumNames<op::v1::BinaryConvolution::BinaryConvolutionMode>(
"op::v1::BinaryConvolution::BinaryConvolutionMode",
{{"xnor-popcount", op::v1::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT}});
return enum_names;
}
NODE_VALIDATION_CHECK(
this, allowed_values.count(mode) > 0, "Invalid binary convolution mode value passed in.");
constexpr DiscreteTypeInfo
AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>::type_info;
return allowed_values.at(mode);
std::ostream& operator<<(std::ostream& s,
const op::v1::BinaryConvolution::BinaryConvolutionMode& type)
{
return s << as_string(type);
}
}
op::v1::BinaryConvolution::BinaryConvolutionMode
op::v1::BinaryConvolution::mode_from_string(const std::string& mode) const
{
return as_enum<BinaryConvolutionMode>(mode);
}
......@@ -74,6 +74,8 @@ namespace ngraph
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void generate_adjoints(autodiff::Adjoints& adjoints,
......@@ -112,4 +114,23 @@ namespace ngraph
};
}
} // namespace op
std::ostream& operator<<(std::ostream& s,
const op::v1::BinaryConvolution::BinaryConvolutionMode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>
: public EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>
{
public:
AttributeAdapter(op::v1::BinaryConvolution::BinaryConvolutionMode& value)
: EnumAttributeAdapterBase<op::v1::BinaryConvolution::BinaryConvolutionMode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment