Unverified Commit 4dc9aa46 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Move non-primitive attribute adapters to adaptee's files (#3949)

* Move non-primitive attribute adapters to adaptee's files

* Cast in copy
parent cc754735
......@@ -18,6 +18,8 @@
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
......@@ -27,20 +29,6 @@
using namespace std;
using namespace ngraph;
namespace
{
template <typename A, typename B>
A copy_from(B& b)
{
A result(b.size());
for (int i = 0; i < b.size(); ++i)
{
result[i] = b[i];
}
return result;
}
}
namespace ngraph
{
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<float>::type_info;
......@@ -234,68 +222,4 @@ namespace ngraph
m_value = copy_from<vector<uint64_t>>(value);
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Shape>::type_info;
const vector<int64_t>& AttributeAdapter<Shape>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Shape>::set(const vector<int64_t>& value)
{
m_value = copy_from<Shape>(value);
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Strides>::type_info;
const vector<int64_t>& AttributeAdapter<Strides>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Strides>::set(const vector<int64_t>& value)
{
m_value = copy_from<Strides>(value);
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<AxisSet>::type_info;
const vector<int64_t>& AttributeAdapter<AxisSet>::get()
{
if (!m_buffer_valid)
{
for (auto elt : m_value)
{
m_buffer.push_back(elt);
}
}
return m_buffer;
}
void AttributeAdapter<AxisSet>::set(const vector<int64_t>& value)
{
m_value = AxisSet();
for (auto elt : value)
{
m_value.insert(elt);
}
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<PartialShape>::type_info;
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info;
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<op::AutoBroadcastSpec>::type_info;
}
......@@ -16,6 +16,7 @@
#pragma once
#include <type_traits>
#include <vector>
#include "ngraph/enum_names.hpp"
......@@ -272,107 +273,15 @@ namespace ngraph
void set(const std::vector<int64_t>& value) override;
};
class Shape;
template <>
class AttributeAdapter<Shape> : public ValueReference<Shape>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(Shape& value)
: ValueReference<Shape>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Shape>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
class Strides;
template <>
class AttributeAdapter<Strides> : public ValueReference<Strides>,
public ValueAccessor<std::vector<int64_t>>
template <typename A, typename B>
A copy_from(B& b)
{
public:
AttributeAdapter(Strides& value)
: ValueReference<Strides>(value)
A result(b.size());
for (int i = 0; i < b.size(); ++i)
{
result[i] =
static_cast<typename std::remove_reference<decltype(result[i])>::type>(b[i]);
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Strides>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
class AxisSet;
template <>
class AttributeAdapter<AxisSet> : public ValueReference<AxisSet>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(AxisSet& value)
: ValueReference<AxisSet>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisSet>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
class PartialShape;
template <>
class AttributeAdapter<PartialShape> : public ValueReference<PartialShape>,
public ValueAccessor<void>
{
public:
AttributeAdapter(PartialShape& value)
: ValueReference<PartialShape>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<PartialShape>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace element
{
class Type;
}
template <>
class AttributeAdapter<element::Type> : public ValueReference<element::Type>,
public ValueAccessor<void>
{
public:
AttributeAdapter(element::Type& value)
: ValueReference<element::Type>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<element::Type>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
struct AutoBroadcastSpec;
return result;
}
template <>
class AttributeAdapter<op::AutoBroadcastSpec> : public ValueReference<op::AutoBroadcastSpec>,
public ValueAccessor<void>
{
public:
AttributeAdapter(op::AutoBroadcastSpec& value)
: ValueReference<op::AutoBroadcastSpec>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::AutoBroadcastSpec>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}
......@@ -17,6 +17,9 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::ostream& ngraph::operator<<(std::ostream& s, const AxisSet& axis_set)
{
s << "AxisSet{";
......@@ -24,3 +27,27 @@ std::ostream& ngraph::operator<<(std::ostream& s, const AxisSet& axis_set)
s << "}";
return s;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<AxisSet>::type_info;
const vector<int64_t>& AttributeAdapter<AxisSet>::get()
{
if (!m_buffer_valid)
{
for (auto elt : m_value)
{
m_buffer.push_back(elt);
}
}
return m_buffer;
}
void AttributeAdapter<AxisSet>::set(const vector<int64_t>& value)
{
m_value = AxisSet();
for (auto elt : value)
{
m_value.insert(elt);
}
m_buffer_valid = false;
}
......@@ -21,6 +21,8 @@
#include <set>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
namespace ngraph
{
/// \brief A set of axes.
......@@ -66,5 +68,21 @@ namespace ngraph
}
};
template <>
class AttributeAdapter<AxisSet> : public ValueReference<AxisSet>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(AxisSet& value)
: ValueReference<AxisSet>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<AxisSet>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
std::ostream& operator<<(std::ostream& s, const AxisSet& axis_set);
}
......@@ -17,6 +17,9 @@
#include "ngraph/coordinate.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::ostream& ngraph::operator<<(std::ostream& s, const Coordinate& coordinate)
{
s << "Coordinate{";
......@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const Coordinate& coordinate)
s << "}";
return s;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Coordinate>::type_info;
const vector<uint64_t>& AttributeAdapter<Coordinate>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<uint64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Coordinate>::set(const vector<uint64_t>& value)
{
m_value = copy_from<Coordinate>(m_value);
m_buffer_valid = false;
}
......@@ -19,6 +19,7 @@
#include <algorithm>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/shape.hpp"
......@@ -73,5 +74,21 @@ namespace ngraph
}
};
template <>
class AttributeAdapter<Coordinate> : public ValueReference<Coordinate>,
public ValueAccessor<std::vector<uint64_t>>
{
public:
AttributeAdapter(Coordinate& value)
: ValueReference<Coordinate>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Coordinate>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<uint64_t>& get() override;
void set(const std::vector<uint64_t>& value) override;
};
std::ostream& operator<<(std::ostream& s, const Coordinate& coordinate);
}
......@@ -17,6 +17,9 @@
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::ostream& ngraph::operator<<(std::ostream& s, const CoordinateDiff& coordinate_diff)
{
s << "CoordinateDiff{";
......@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const CoordinateDiff& coordina
s << "}";
return s;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<CoordinateDiff>::type_info;
const vector<int64_t>& AttributeAdapter<CoordinateDiff>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<CoordinateDiff>::set(const vector<int64_t>& value)
{
m_value = copy_from<CoordinateDiff>(m_value);
m_buffer_valid = false;
}
......@@ -20,6 +20,8 @@
#include <ostream>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
namespace ngraph
{
/// \brief A difference (signed) of tensor element coordinates.
......@@ -65,5 +67,21 @@ namespace ngraph
}
};
template <>
class AttributeAdapter<CoordinateDiff> : public ValueReference<CoordinateDiff>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(CoordinateDiff& value)
: ValueReference<CoordinateDiff>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<CoordinateDiff>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
std::ostream& operator<<(std::ostream& s, const CoordinateDiff& coordinate_diff);
}
......@@ -141,4 +141,6 @@ namespace ngraph
return allowed_values.at(type);
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<op::AutoBroadcastSpec>::type_info;
}
......@@ -284,4 +284,18 @@ namespace ngraph
AutoBroadcastType type_from_string(const std::string& type) const;
};
}
template <>
class AttributeAdapter<op::AutoBroadcastSpec> : public ValueReference<op::AutoBroadcastSpec>,
public ValueAccessor<void>
{
public:
AttributeAdapter(op::AutoBroadcastSpec& value)
: ValueReference<op::AutoBroadcastSpec>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::AutoBroadcastSpec>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}
......@@ -339,3 +339,5 @@ bool PartialShape::all_non_negative() const
return true;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<PartialShape>::type_info;
......@@ -18,6 +18,7 @@
#include <stddef.h>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/dimension.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/rank.hpp"
......@@ -282,4 +283,18 @@ namespace ngraph
/// {2,3,4}
/// \endcode
std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
template <>
class AttributeAdapter<PartialShape> : public ValueReference<PartialShape>,
public ValueAccessor<void>
{
public:
AttributeAdapter(PartialShape& value)
: ValueReference<PartialShape>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<PartialShape>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}
......@@ -17,6 +17,9 @@
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::ostream& ngraph::operator<<(std::ostream& s, const Shape& shape)
{
s << "Shape{";
......@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const Shape& shape)
s << "}";
return s;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Shape>::type_info;
const vector<int64_t>& AttributeAdapter<Shape>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Shape>::set(const vector<int64_t>& value)
{
m_value = copy_from<Shape>(value);
m_buffer_valid = false;
}
......@@ -19,6 +19,7 @@
#include <cstdio>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/strides.hpp"
......@@ -67,6 +68,22 @@ namespace ngraph
}
};
template <>
class AttributeAdapter<Shape> : public ValueReference<Shape>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(Shape& value)
: ValueReference<Shape>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Shape>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
/// Number of elements in spanned by a shape
template <typename SHAPE_TYPE>
size_t shape_size(const SHAPE_TYPE& shape)
......
......@@ -17,6 +17,9 @@
#include "ngraph/strides.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::ostream& ngraph::operator<<(std::ostream& s, const Strides& strides)
{
s << "Strides{";
......@@ -24,3 +27,21 @@ std::ostream& ngraph::operator<<(std::ostream& s, const Strides& strides)
s << "}";
return s;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Strides>::type_info;
const vector<int64_t>& AttributeAdapter<Strides>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Strides>::set(const vector<int64_t>& value)
{
m_value = copy_from<Strides>(value);
m_buffer_valid = false;
}
......@@ -20,6 +20,8 @@
#include <ostream>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
namespace ngraph
{
/// \brief Strides for a tensor.
......@@ -65,5 +67,21 @@ namespace ngraph
}
};
template <>
class AttributeAdapter<Strides> : public ValueReference<Strides>,
public ValueAccessor<std::vector<int64_t>>
{
public:
AttributeAdapter(Strides& value)
: ValueReference<Strides>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Strides>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<int64_t>& get() override;
void set(const std::vector<int64_t>& value) override;
};
std::ostream& operator<<(std::ostream& s, const Strides& strides);
}
......@@ -40,6 +40,8 @@ NGRAPH_API const element::Type element::u16(element::Type_t::u16);
NGRAPH_API const element::Type element::u32(element::Type_t::u32);
NGRAPH_API const element::Type element::u64(element::Type_t::u64);
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info;
class TypeInfo
{
public:
......
......@@ -26,6 +26,7 @@
#include <string>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/deprecated.hpp"
#include "ngraph/except.hpp"
#include "ngraph/ngraph_visibility.hpp"
......@@ -182,4 +183,17 @@ namespace ngraph
std::ostream& operator<<(std::ostream& out, const ngraph::element::Type& obj);
}
template <>
class AttributeAdapter<element::Type> : public ValueReference<element::Type>,
public ValueAccessor<void>
{
public:
AttributeAdapter(element::Type& value)
: ValueReference<element::Type>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<element::Type>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment