Unverified Commit 4dc9aa46 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Move non-primitive attribute adapters to adaptee's files (#3949)

* Move non-primitive attribute adapters to adaptee's files

* Cast in copy
parent cc754735
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include "ngraph/attribute_adapter.hpp" #include "ngraph/attribute_adapter.hpp"
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/partial_shape.hpp" #include "ngraph/partial_shape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/strides.hpp" #include "ngraph/strides.hpp"
...@@ -27,20 +29,6 @@ ...@@ -27,20 +29,6 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
namespace
{
template <typename A, typename B>
A copy_from(B& b)
{
A result(b.size());
for (int i = 0; i < b.size(); ++i)
{
result[i] = b[i];
}
return result;
}
}
namespace ngraph namespace ngraph
{ {
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<float>::type_info; NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<float>::type_info;
...@@ -234,68 +222,4 @@ namespace ngraph ...@@ -234,68 +222,4 @@ namespace ngraph
m_value = copy_from<vector<uint64_t>>(value); m_value = copy_from<vector<uint64_t>>(value);
m_buffer_valid = false; m_buffer_valid = false;
} }
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Shape>::type_info;
const vector<int64_t>& AttributeAdapter<Shape>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Shape>::set(const vector<int64_t>& value)
{
m_value = copy_from<Shape>(value);
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Strides>::type_info;
const vector<int64_t>& AttributeAdapter<Strides>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Strides>::set(const vector<int64_t>& value)
{
m_value = copy_from<Strides>(value);
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<AxisSet>::type_info;
const vector<int64_t>& AttributeAdapter<AxisSet>::get()
{
if (!m_buffer_valid)
{
for (auto elt : m_value)
{
m_buffer.push_back(elt);
}
}
return m_buffer;
}
void AttributeAdapter<AxisSet>::set(const vector<int64_t>& value)
{
m_value = AxisSet();
for (auto elt : value)
{
m_value.insert(elt);
}
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<PartialShape>::type_info;
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info;
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<op::AutoBroadcastSpec>::type_info;
} }
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#pragma once #pragma once
#include <type_traits>
#include <vector> #include <vector>
#include "ngraph/enum_names.hpp" #include "ngraph/enum_names.hpp"
...@@ -272,107 +273,15 @@ namespace ngraph ...@@ -272,107 +273,15 @@ namespace ngraph
void set(const std::vector<int64_t>& value) override; void set(const std::vector<int64_t>& value) override;
}; };
class Shape; template <typename A, typename B>
template <> A copy_from(B& b)
class AttributeAdapter<Shape> : public ValueReference<Shape>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(Shape& value)
: ValueReference<Shape>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Shape>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
class Strides;
template <>
class AttributeAdapter<Strides> : public ValueReference<Strides>,
public ValueAccessor<std::vector<int64_t>>
{ {
public: A result(b.size());
AttributeAdapter(Strides& value) for (int i = 0; i < b.size(); ++i)
: ValueReference<Strides>(value)
{ {
result[i] =
static_cast<typename std::remove_reference<decltype(result[i])>::type>(b[i]);
} }
NGRAPH_API return result;
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Strides>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
class AxisSet;
template <>
class AttributeAdapter<AxisSet> : public ValueReference<AxisSet>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(AxisSet& value)
: ValueReference<AxisSet>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisSet>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
class PartialShape;
template <>
class AttributeAdapter<PartialShape> : public ValueReference<PartialShape>,
public ValueAccessor<void>
{
public:
AttributeAdapter(PartialShape& value)
: ValueReference<PartialShape>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<PartialShape>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace element
{
class Type;
}
template <>
class AttributeAdapter<element::Type> : public ValueReference<element::Type>,
public ValueAccessor<void>
{
public:
AttributeAdapter(element::Type& value)
: ValueReference<element::Type>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<element::Type>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
struct AutoBroadcastSpec;
} }
template <>
class AttributeAdapter<op::AutoBroadcastSpec> : public ValueReference<op::AutoBroadcastSpec>,
public ValueAccessor<void>
{
public:
AttributeAdapter(op::AutoBroadcastSpec& value)
: ValueReference<op::AutoBroadcastSpec>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::AutoBroadcastSpec>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} }
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::ostream& ngraph::operator<<(std::ostream& s, const AxisSet& axis_set) std::ostream& ngraph::operator<<(std::ostream& s, const AxisSet& axis_set)
{ {
s << "AxisSet{"; s << "AxisSet{";
...@@ -24,3 +27,27 @@ std::ostream& ngraph::operator<<(std::ostream& s, const AxisSet& axis_set) ...@@ -24,3 +27,27 @@ std::ostream& ngraph::operator<<(std::ostream& s, const AxisSet& axis_set)
s << "}"; s << "}";
return s; return s;
} }
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<AxisSet>::type_info;
const vector<int64_t>& AttributeAdapter<AxisSet>::get()
{
if (!m_buffer_valid)
{
for (auto elt : m_value)
{
m_buffer.push_back(elt);
}
}
return m_buffer;
}
void AttributeAdapter<AxisSet>::set(const vector<int64_t>& value)
{
m_value = AxisSet();
for (auto elt : value)
{
m_value.insert(elt);
}
m_buffer_valid = false;
}
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include <set> #include <set>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp"
namespace ngraph namespace ngraph
{ {
/// \brief A set of axes. /// \brief A set of axes.
...@@ -66,5 +68,21 @@ namespace ngraph ...@@ -66,5 +68,21 @@ namespace ngraph
} }
}; };
template <>
class AttributeAdapter<AxisSet> : public ValueReference<AxisSet>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(AxisSet& value)
: ValueReference<AxisSet>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisSet>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
std::ostream& operator<<(std::ostream& s, const AxisSet& axis_set); std::ostream& operator<<(std::ostream& s, const AxisSet& axis_set);
} }
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include "ngraph/coordinate.hpp" #include "ngraph/coordinate.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::ostream& ngraph::operator<<(std::ostream& s, const Coordinate& coordinate) std::ostream& ngraph::operator<<(std::ostream& s, const Coordinate& coordinate)
{ {
s << "Coordinate{"; s << "Coordinate{";
...@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const Coordinate& coordinate) ...@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const Coordinate& coordinate)
s << "}"; s << "}";
return s; return s;
} }
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Coordinate>::type_info;
const vector<uint64_t>& AttributeAdapter<Coordinate>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<uint64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Coordinate>::set(const vector<uint64_t>& value)
{
m_value = copy_from<Coordinate>(m_value);
m_buffer_valid = false;
}
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
...@@ -73,5 +74,21 @@ namespace ngraph ...@@ -73,5 +74,21 @@ namespace ngraph
} }
}; };
template <>
class AttributeAdapter<Coordinate> : public ValueReference<Coordinate>,
public ValueAccessor<std::vector<uint64_t>>
{
public:
AttributeAdapter(Coordinate& value)
: ValueReference<Coordinate>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Coordinate>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<uint64_t>& get() override;
void set(const std::vector<uint64_t>& value) override;
};
std::ostream& operator<<(std::ostream& s, const Coordinate& coordinate); std::ostream& operator<<(std::ostream& s, const Coordinate& coordinate);
} }
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include "ngraph/coordinate_diff.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::ostream& ngraph::operator<<(std::ostream& s, const CoordinateDiff& coordinate_diff) std::ostream& ngraph::operator<<(std::ostream& s, const CoordinateDiff& coordinate_diff)
{ {
s << "CoordinateDiff{"; s << "CoordinateDiff{";
...@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const CoordinateDiff& coordina ...@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const CoordinateDiff& coordina
s << "}"; s << "}";
return s; return s;
} }
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<CoordinateDiff>::type_info;
const vector<int64_t>& AttributeAdapter<CoordinateDiff>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<CoordinateDiff>::set(const vector<int64_t>& value)
{
m_value = copy_from<CoordinateDiff>(m_value);
m_buffer_valid = false;
}
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include <ostream> #include <ostream>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp"
namespace ngraph namespace ngraph
{ {
/// \brief A difference (signed) of tensor element coordinates. /// \brief A difference (signed) of tensor element coordinates.
...@@ -65,5 +67,21 @@ namespace ngraph ...@@ -65,5 +67,21 @@ namespace ngraph
} }
}; };
template <>
class AttributeAdapter<CoordinateDiff> : public ValueReference<CoordinateDiff>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(CoordinateDiff& value)
: ValueReference<CoordinateDiff>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<CoordinateDiff>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
std::ostream& operator<<(std::ostream& s, const CoordinateDiff& coordinate_diff); std::ostream& operator<<(std::ostream& s, const CoordinateDiff& coordinate_diff);
} }
...@@ -141,4 +141,6 @@ namespace ngraph ...@@ -141,4 +141,6 @@ namespace ngraph
return allowed_values.at(type); return allowed_values.at(type);
} }
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<op::AutoBroadcastSpec>::type_info;
} }
...@@ -284,4 +284,18 @@ namespace ngraph ...@@ -284,4 +284,18 @@ namespace ngraph
AutoBroadcastType type_from_string(const std::string& type) const; AutoBroadcastType type_from_string(const std::string& type) const;
}; };
} }
template <>
class AttributeAdapter<op::AutoBroadcastSpec> : public ValueReference<op::AutoBroadcastSpec>,
public ValueAccessor<void>
{
public:
AttributeAdapter(op::AutoBroadcastSpec& value)
: ValueReference<op::AutoBroadcastSpec>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::AutoBroadcastSpec>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} }
...@@ -339,3 +339,5 @@ bool PartialShape::all_non_negative() const ...@@ -339,3 +339,5 @@ bool PartialShape::all_non_negative() const
return true; return true;
} }
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<PartialShape>::type_info;
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <stddef.h> #include <stddef.h>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/dimension.hpp" #include "ngraph/dimension.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
#include "ngraph/rank.hpp" #include "ngraph/rank.hpp"
...@@ -282,4 +283,18 @@ namespace ngraph ...@@ -282,4 +283,18 @@ namespace ngraph
/// {2,3,4} /// {2,3,4}
/// \endcode /// \endcode
std::ostream& operator<<(std::ostream& str, const PartialShape& shape); std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
template <>
class AttributeAdapter<PartialShape> : public ValueReference<PartialShape>,
public ValueAccessor<void>
{
public:
AttributeAdapter(PartialShape& value)
: ValueReference<PartialShape>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<PartialShape>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} }
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::ostream& ngraph::operator<<(std::ostream& s, const Shape& shape) std::ostream& ngraph::operator<<(std::ostream& s, const Shape& shape)
{ {
s << "Shape{"; s << "Shape{";
...@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const Shape& shape) ...@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const Shape& shape)
s << "}"; s << "}";
return s; return s;
} }
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Shape>::type_info;
const vector<int64_t>& AttributeAdapter<Shape>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Shape>::set(const vector<int64_t>& value)
{
m_value = copy_from<Shape>(value);
m_buffer_valid = false;
}
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <cstdio> #include <cstdio>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/strides.hpp" #include "ngraph/strides.hpp"
...@@ -67,6 +68,22 @@ namespace ngraph ...@@ -67,6 +68,22 @@ namespace ngraph
} }
}; };
template <>
class AttributeAdapter<Shape> : public ValueReference<Shape>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(Shape& value)
: ValueReference<Shape>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Shape>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
/// Number of elements in spanned by a shape /// Number of elements in spanned by a shape
template <typename SHAPE_TYPE> template <typename SHAPE_TYPE>
size_t shape_size(const SHAPE_TYPE& shape) size_t shape_size(const SHAPE_TYPE& shape)
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include "ngraph/strides.hpp" #include "ngraph/strides.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::ostream& ngraph::operator<<(std::ostream& s, const Strides& strides) std::ostream& ngraph::operator<<(std::ostream& s, const Strides& strides)
{ {
s << "Strides{"; s << "Strides{";
...@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const Strides& strides) ...@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const Strides& strides)
s << "}"; s << "}";
return s; return s;
} }
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Strides>::type_info;
const vector<int64_t>& AttributeAdapter<Strides>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Strides>::set(const vector<int64_t>& value)
{
m_value = copy_from<Strides>(value);
m_buffer_valid = false;
}
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include <ostream> #include <ostream>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp"
namespace ngraph namespace ngraph
{ {
/// \brief Strides for a tensor. /// \brief Strides for a tensor.
...@@ -65,5 +67,21 @@ namespace ngraph ...@@ -65,5 +67,21 @@ namespace ngraph
} }
}; };
template <>
class AttributeAdapter<Strides> : public ValueReference<Strides>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(Strides& value)
: ValueReference<Strides>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Strides>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
std::ostream& operator<<(std::ostream& s, const Strides& strides); std::ostream& operator<<(std::ostream& s, const Strides& strides);
} }
...@@ -40,6 +40,8 @@ NGRAPH_API const element::Type element::u16(element::Type_t::u16); ...@@ -40,6 +40,8 @@ NGRAPH_API const element::Type element::u16(element::Type_t::u16);
NGRAPH_API const element::Type element::u32(element::Type_t::u32); NGRAPH_API const element::Type element::u32(element::Type_t::u32);
NGRAPH_API const element::Type element::u64(element::Type_t::u64); NGRAPH_API const element::Type element::u64(element::Type_t::u64);
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info;
class TypeInfo class TypeInfo
{ {
public: public:
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/deprecated.hpp" #include "ngraph/deprecated.hpp"
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
...@@ -182,4 +183,17 @@ namespace ngraph ...@@ -182,4 +183,17 @@ namespace ngraph
std::ostream& operator<<(std::ostream& out, const ngraph::element::Type& obj); std::ostream& operator<<(std::ostream& out, const ngraph::element::Type& obj);
} }
template <>
class AttributeAdapter<element::Type> : public ValueReference<element::Type>,
public ValueAccessor<void>
{
public:
AttributeAdapter(element::Type& value)
: ValueReference<element::Type>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<element::Type>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment