Unverified Commit bc1ca25b authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Attribute visitor (#3579)

* Sketch of attribute walker

* Review comments

* merge error?

* Remove unused method

* simplify, make some ser tests work

* Don't look for keys that aren't there

* Factory registry, more ops visited, generic ser/dser start

* More merge

* cleanup

* Adapter for enums

* Compiler error

* Test of user-defined op

* Simplify enum name pairing

* Update distributed.hpp

* Review comments

* compiler error

* Direct access to non-primitive types from adapters

* Define and export type info

* attr enums, AvgPool*, vectors

* Cleanup

* some comments

* Allow type info to be used as a key.

* Don't leave output serialization shapes set.

* Auto adapter

* More ops, adapters

* Missing symbol

* Remove PartialShape and element::Type methods from visitor

* Fix type info

* Remove unused variable

* Simplify

* namespace error

* exports

* Uniform names

* Some better names

* More name cleanup, simplify visitor implementation

* Fix template, add test

* Revert serializer

* Add instantiations

* Work-around gcc issue

* VS exports

* VS exports

* windows export

* vs

* vs

* vs

* vs

* Simplify

* vs

* vs

* Add some missing attributes

* Missing factories

* Merge error

* Fix Add factories

* Missed type
parent 5fa5854c
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
set (SRC set (SRC
assertion.hpp assertion.hpp
attribute_adapter.cpp
attribute_adapter.hpp
attribute_visitor.hpp
autodiff/adjoints.cpp autodiff/adjoints.cpp
autodiff/adjoints.hpp autodiff/adjoints.hpp
axis_set.cpp axis_set.cpp
...@@ -77,7 +80,10 @@ set (SRC ...@@ -77,7 +80,10 @@ set (SRC
dimension.hpp dimension.hpp
distributed.cpp distributed.cpp
distributed.hpp distributed.hpp
enum_names.hpp
except.hpp except.hpp
factory.cpp
factory.hpp
file_util.cpp file_util.cpp
file_util.hpp file_util.hpp
function.cpp function.cpp
...@@ -555,6 +561,7 @@ set (SRC ...@@ -555,6 +561,7 @@ set (SRC
type/float16.cpp type/float16.cpp
type/float16.hpp type/float16.hpp
type/element_type.cpp type/element_type.cpp
type.hpp
util.cpp util.cpp
util.hpp util.hpp
validation_util.cpp validation_util.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp"
using namespace std;
using namespace ngraph;
namespace
{
template <typename A, typename B>
A copy_from(B& b)
{
A result(b.size());
for (int i = 0; i < b.size(); ++i)
{
result[i] = b[i];
}
return result;
}
}
namespace ngraph
{
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<float>::type_info;
const double& AttributeAdapter<float>::get()
{
if (!m_buffer_valid)
{
m_buffer = m_value;
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<float>::set(const double& value)
{
m_value = value;
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<double>::type_info;
const double& AttributeAdapter<double>::get()
{
if (!m_buffer_valid)
{
m_buffer = m_value;
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<double>::set(const double& value)
{
m_value = value;
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<int8_t>::type_info;
const int64_t& AttributeAdapter<int8_t>::get()
{
if (!m_buffer_valid)
{
m_buffer = m_value;
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<int8_t>::set(const int64_t& value)
{
m_value = value;
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<int16_t>::type_info;
const int64_t& AttributeAdapter<int16_t>::get()
{
if (!m_buffer_valid)
{
m_buffer = m_value;
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<int16_t>::set(const int64_t& value)
{
m_value = value;
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<int32_t>::type_info;
const int64_t& AttributeAdapter<int32_t>::get()
{
if (!m_buffer_valid)
{
m_buffer = m_value;
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<int32_t>::set(const int64_t& value)
{
m_value = value;
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<int64_t>::type_info;
const int64_t& AttributeAdapter<int64_t>::get()
{
if (!m_buffer_valid)
{
m_buffer = m_value;
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<int64_t>::set(const int64_t& value)
{
m_value = value;
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<uint8_t>::type_info;
const int64_t& AttributeAdapter<uint8_t>::get()
{
if (!m_buffer_valid)
{
m_buffer = m_value;
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<uint8_t>::set(const int64_t& value)
{
m_value = value;
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<uint16_t>::type_info;
const int64_t& AttributeAdapter<uint16_t>::get()
{
if (!m_buffer_valid)
{
m_buffer = m_value;
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<uint16_t>::set(const int64_t& value)
{
m_value = value;
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<uint32_t>::type_info;
const int64_t& AttributeAdapter<uint32_t>::get()
{
if (!m_buffer_valid)
{
m_buffer = m_value;
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<uint32_t>::set(const int64_t& value)
{
m_value = value;
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<uint64_t>::type_info;
const int64_t& AttributeAdapter<uint64_t>::get()
{
if (!m_buffer_valid)
{
m_buffer = m_value;
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<uint64_t>::set(const int64_t& value)
{
m_value = value;
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<vector<int64_t>>::type_info;
const vector<int64_t>& AttributeAdapter<vector<int64_t>>::get() { return m_value; }
void AttributeAdapter<vector<int64_t>>::set(const vector<int64_t>& value) { m_value = value; }
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<vector<uint64_t>>::type_info;
const vector<int64_t>& AttributeAdapter<vector<uint64_t>>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<vector<uint64_t>>::set(const vector<int64_t>& value)
{
m_value = copy_from<vector<uint64_t>>(value);
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Shape>::type_info;
const vector<int64_t>& AttributeAdapter<Shape>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Shape>::set(const vector<int64_t>& value)
{
m_value = copy_from<Shape>(value);
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Strides>::type_info;
const vector<int64_t>& AttributeAdapter<Strides>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<int64_t>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<Strides>::set(const vector<int64_t>& value)
{
m_value = copy_from<Strides>(value);
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<AxisSet>::type_info;
const vector<int64_t>& AttributeAdapter<AxisSet>::get()
{
if (!m_buffer_valid)
{
for (auto elt : m_value)
{
m_buffer.push_back(elt);
}
}
return m_buffer;
}
void AttributeAdapter<AxisSet>::set(const vector<int64_t>& value)
{
m_value = AxisSet();
for (auto elt : value)
{
m_value.insert(elt);
}
m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<PartialShape>::type_info;
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info;
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<op::AutoBroadcastSpec>::type_info;
}
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <string>
#include <utility>
#include "ngraph/partial_shape.hpp"
#include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
template <typename T>
class ValueAccessor;
/// \brief Visits the attributes of a node.
///
/// Attributes are the values set when building a graph which are not
/// computed as the graph executes. Values computed from the graph topology and attributes
/// during compilation are not attributes.
class AttributeVisitor
{
public:
virtual ~AttributeVisitor() {}
// Must implement these methods
virtual void on_attribute(const std::string& name, std::string& value) = 0;
virtual void on_attribute(const std::string& name, bool& value) = 0;
virtual void on_adapter(const std::string& name, ValueAccessor<void>& adapter) = 0;
// The remaining adapter methods fall back on the void adapter if not implemented
virtual void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter)
{
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
};
virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<int64_t>>& adapter)
{
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}
virtual void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter)
{
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}
virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter)
{
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}
// Use an adapter for non-primitive types
template <typename T>
// typename std::enable_if<std::is_class<T>::value, void>::type
void on_attribute(const std::string& name, T& value)
{
AttributeAdapter<T> adapter(value);
on_adapter(name, adapter);
}
};
}
...@@ -19,28 +19,30 @@ ...@@ -19,28 +19,30 @@
#include "ngraph/distributed/null.hpp" #include "ngraph/distributed/null.hpp"
#include "ngraph/distributed/open_mpi.hpp" #include "ngraph/distributed/open_mpi.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/type.hpp"
using namespace ngraph; using namespace ngraph;
std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj) namespace ngraph
{ {
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8) template <>
#pragma GCC diagnostic push EnumNames<reduction::Type>& EnumNames<reduction::Type>::get()
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (obj)
{ {
case reduction::Type::SUM: out << "SUM"; break; static auto enum_names = EnumNames<reduction::Type>("reduction::Type",
case reduction::Type::PROD: out << "PROD"; break; {{"SUM", reduction::Type::SUM},
case reduction::Type::MIN: out << "MIN"; break; {"PROD", reduction::Type::PROD},
case reduction::Type::MAX: out << "MAX"; break; {"MIN", reduction::Type::MIN},
{"MAX", reduction::Type::MAX}});
return enum_names;
} }
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop constexpr DiscreteTypeInfo AttributeAdapter<reduction::Type>::type_info;
#endif }
return out;
}; std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj)
{
return out << as_string(obj);
}
static std::unique_ptr<DistributedInterface> s_distributed_interface; static std::unique_ptr<DistributedInterface> s_distributed_interface;
......
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
namespace ngraph namespace ngraph
...@@ -37,6 +39,20 @@ namespace ngraph ...@@ -37,6 +39,20 @@ namespace ngraph
std::ostream& operator<<(std::ostream& out, const Type& obj); std::ostream& operator<<(std::ostream& out, const Type& obj);
} }
template <>
class AttributeAdapter<reduction::Type> : public EnumAttributeAdapterBase<reduction::Type>
{
public:
AttributeAdapter(reduction::Type& value)
: EnumAttributeAdapterBase<reduction::Type>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<reduction::Type>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
class DistributedInterface class DistributedInterface
{ {
public: public:
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <string>
#include <utility>
#include "ngraph/check.hpp"
namespace ngraph
{
/// Uses a pairings defined by EnumTypes::get() to convert between strings
/// and enum values.
template <typename EnumType>
class EnumNames
{
public:
/// Converts strings to enum values
static EnumType as_enum(const std::string& name)
{
for (auto p : get().m_string_enums)
{
if (p.first == name)
{
return p.second;
}
}
NGRAPH_CHECK(false, "\"", name, "\"", " is not a member of enum ", get().m_enum_name);
}
/// Converts enum values to strings
static const std::string& as_string(EnumType e)
{
for (auto& p : get().m_string_enums)
{
if (p.second == e)
{
return p.first;
}
}
NGRAPH_CHECK(false, " invalid member of enum ", get().m_enum_name);
}
private:
/// Creates the mapping.
EnumNames(const std::string& enum_name,
const std::vector<std::pair<std::string, EnumType>> string_enums)
: m_enum_name(enum_name)
, m_string_enums(string_enums)
{
}
/// Must be defined to returns a singleton for each supported enum class
static EnumNames<EnumType>& get();
const std::string m_enum_name;
std::vector<std::pair<std::string, EnumType>> m_string_enums;
};
/// Returns the enum value matching the string
template <typename Type, typename Value>
typename std::enable_if<std::is_convertible<Value, std::string>::value, Type>::type
as_enum(const Value& value)
{
return EnumNames<Type>::as_enum(value);
}
/// Returns the string matching the enum value
template <typename Value>
const std::string& as_string(Value value)
{
return EnumNames<Value>::as_string(value);
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <mutex>
#include "ngraph/factory.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/parameter.hpp"
using namespace std;
namespace ngraph
{
mutex& get_registry_mutex()
{
static mutex registry_mutex;
return registry_mutex;
}
template <>
FactoryRegistry<Node>& FactoryRegistry<Node>::get()
{
static FactoryRegistry<Node> registry;
static mutex init_guard;
// TODO: Add a lock
if (registry.m_factory_map.size() == 0)
{
lock_guard<mutex> guard(init_guard);
if (registry.m_factory_map.size() == 0)
{
registry.register_factory<op::Abs>();
registry.register_factory<op::Acos>();
registry.register_factory<op::v0::Add>();
registry.register_factory<op::v1::Add>();
registry.register_factory<op::All>();
registry.register_factory<op::AllReduce>();
registry.register_factory<op::And>();
registry.register_factory<op::Any>();
registry.register_factory<op::ArgMax>();
registry.register_factory<op::ArgMin>();
registry.register_factory<op::v0::AvgPool>();
registry.register_factory<op::v0::AvgPoolBackprop>();
registry.register_factory<op::v1::AvgPool>();
registry.register_factory<op::v1::AvgPoolBackprop>();
registry.register_factory<op::BatchNormInference>();
registry.register_factory<op::BatchNormTraining>();
registry.register_factory<op::BatchNormTrainingBackprop>();
registry.register_factory<op::BroadcastDistributed>();
registry.register_factory<op::v0::Broadcast>();
registry.register_factory<op::v0::BroadcastLike>();
registry.register_factory<op::v1::Broadcast>();
registry.register_factory<op::Ceiling>();
registry.register_factory<op::Concat>();
registry.register_factory<op::v1::LogicalAnd>();
registry.register_factory<op::Parameter>();
}
}
return registry;
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <map>
#include <mutex>
namespace ngraph
{
std::mutex& get_registry_mutex();
template <typename T>
class FactoryRegistry
{
public:
using Factory = std::function<T*()>;
using FactoryMap = std::map<decltype(T::type_info), Factory>;
template <typename U>
void register_factory()
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
m_factory_map[U::type_info] = []() { return new U(); };
}
bool has_factory(const decltype(T::type_info) & info)
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
return m_factory_map.find(info) != m_factory_map.end();
}
T* create(const decltype(T::type_info) & info)
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
auto it = m_factory_map.find(info);
return it == m_factory_map.end() ? nullptr : it->second();
}
static FactoryRegistry<T>& get();
protected:
// Need a Compare on type_info
FactoryMap m_factory_map;
};
}
...@@ -61,6 +61,8 @@ namespace ngraph ...@@ -61,6 +61,8 @@ namespace ngraph
/// \brief Convenience functions that create addional graph nodes to implement commonly-used /// \brief Convenience functions that create addional graph nodes to implement commonly-used
/// recipes, for example auto-broadcast. /// recipes, for example auto-broadcast.
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp" #include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/dequantize_builder.hpp" #include "ngraph/builder/dequantize_builder.hpp"
#include "ngraph/builder/numpy_transpose.hpp" #include "ngraph/builder/numpy_transpose.hpp"
...@@ -80,6 +82,7 @@ namespace ngraph ...@@ -80,6 +82,7 @@ namespace ngraph
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/dimension.hpp" #include "ngraph/dimension.hpp"
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "ngraph/factory.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/lambda.hpp" #include "ngraph/lambda.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
...@@ -221,4 +224,5 @@ namespace ngraph ...@@ -221,4 +224,5 @@ namespace ngraph
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/shape_util.hpp" #include "ngraph/shape_util.hpp"
#include "ngraph/specialize_function.hpp" #include "ngraph/specialize_function.hpp"
#include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
...@@ -50,6 +50,7 @@ namespace ngraph ...@@ -50,6 +50,7 @@ namespace ngraph
template <typename NodeType> template <typename NodeType>
class Output; class Output;
class AttributeVisitor;
class Variant; class Variant;
class Node; class Node;
using NodeVector = std::vector<std::shared_ptr<Node>>; using NodeVector = std::vector<std::shared_ptr<Node>>;
...@@ -110,13 +111,14 @@ namespace ngraph ...@@ -110,13 +111,14 @@ namespace ngraph
template <typename NodeType> template <typename NodeType>
friend class Output; friend class Output;
protected: public:
/// Throws if the node is invalid. /// Throws if the node is invalid.
virtual void validate_and_infer_types(); virtual void validate_and_infer_types();
// Called in constructors during transition // Called in constructors during transition
void constructor_validate_and_infer_types(); void constructor_validate_and_infer_types();
protected:
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args( std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec()); const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
void validate_and_infer_elementwise_arithmetic( void validate_and_infer_elementwise_arithmetic(
...@@ -157,6 +159,7 @@ namespace ngraph ...@@ -157,6 +159,7 @@ namespace ngraph
virtual ~Node(); virtual ~Node();
virtual bool visit_attributes(AttributeVisitor& visitor) { return false; }
virtual bool is_unary_elementwise_arithmetic() const { return false; } virtual bool is_unary_elementwise_arithmetic() const { return false; }
virtual bool is_binary_elementwise_arithmetic() const { return false; } virtual bool is_binary_elementwise_arithmetic() const { return false; }
virtual bool is_binary_elementwise_comparison() const { return false; } virtual bool is_binary_elementwise_comparison() const { return false; }
......
...@@ -34,7 +34,7 @@ namespace ngraph ...@@ -34,7 +34,7 @@ namespace ngraph
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an absolute value operation. /// \brief Constructs an absolute value operation.
Abs() = default; Abs() = default;
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
/// \brief Constructs an absolute value operation. /// \brief Constructs an absolute value operation.
/// ///
/// \param arg Output that produces the input tensor.<br> /// \param arg Output that produces the input tensor.<br>
......
...@@ -42,7 +42,7 @@ namespace ngraph ...@@ -42,7 +42,7 @@ namespace ngraph
/// Output `[d1, ...]` /// Output `[d1, ...]`
/// ///
Acos(const Output<Node>& arg); Acos(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected: protected:
......
...@@ -37,6 +37,12 @@ shared_ptr<Node> op::v0::Add::copy_with_new_args(const NodeVector& new_args) con ...@@ -37,6 +37,12 @@ shared_ptr<Node> op::v0::Add::copy_with_new_args(const NodeVector& new_args) con
return make_shared<op::v0::Add>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v0::Add>(new_args.at(0), new_args.at(1), this->get_autob());
} }
bool op::v0::Add::visit_attributes(AttributeVisitor& visitor)
{
BinaryElementwiseArithmetic::visit_attributes(visitor);
return true;
}
void op::v0::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::v0::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
if (get_autob().m_type != op::AutoBroadcastType::NONE) if (get_autob().m_type != op::AutoBroadcastType::NONE)
...@@ -70,6 +76,12 @@ op::v1::Add::Add(const Output<Node>& arg0, ...@@ -70,6 +76,12 @@ op::v1::Add::Add(const Output<Node>& arg0,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v1::Add::visit_attributes(AttributeVisitor& visitor)
{
BinaryElementwiseArithmetic::visit_attributes(visitor);
return true;
}
shared_ptr<Node> op::v1::Add::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::Add::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -52,7 +52,7 @@ namespace ngraph ...@@ -52,7 +52,7 @@ namespace ngraph
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
...@@ -90,7 +90,7 @@ namespace ngraph ...@@ -90,7 +90,7 @@ namespace ngraph
AutoBroadcastSpec(AutoBroadcastType::NUMPY)); AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
protected: protected:
......
...@@ -41,7 +41,7 @@ namespace ngraph ...@@ -41,7 +41,7 @@ namespace ngraph
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
All(const Output<Node>& arg, const Output<Node>& reduction_axes); All(const Output<Node>& arg, const Output<Node>& reduction_axes);
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
/// \return The default value for All. /// \return The default value for All.
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/allreduce.hpp" #include "ngraph/op/allreduce.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/type.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -47,6 +49,12 @@ shared_ptr<Node> op::AllReduce::copy_with_new_args(const NodeVector& new_args) c ...@@ -47,6 +49,12 @@ shared_ptr<Node> op::AllReduce::copy_with_new_args(const NodeVector& new_args) c
return make_shared<AllReduce>(new_args.at(0), get_reduce_type()); return make_shared<AllReduce>(new_args.at(0), get_reduce_type());
} }
bool op::AllReduce::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("reduce_type", m_reduce_type);
return true;
}
reduction::Type op::AllReduce::get_reduce_type() const reduction::Type op::AllReduce::get_reduce_type() const
{ {
return m_reduce_type; return m_reduce_type;
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
reduction::Type get_reduce_type() const; reduction::Type get_reduce_type() const;
void set_reduce_type(reduction::Type reduce_type); void set_reduce_type(reduction::Type reduce_type);
bool visit_attributes(AttributeVisitor& visitor) override;
private: private:
reduction::Type m_reduce_type{reduction::Type::SUM}; reduction::Type m_reduce_type{reduction::Type::SUM};
......
...@@ -29,6 +29,12 @@ op::v1::LogicalAnd::LogicalAnd(const Output<Node>& arg0, ...@@ -29,6 +29,12 @@ op::v1::LogicalAnd::LogicalAnd(const Output<Node>& arg0,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v1::LogicalAnd::visit_attributes(AttributeVisitor& visitor)
{
BinaryElementwiseLogical::visit_attributes(visitor);
return true;
}
shared_ptr<Node> op::v1::LogicalAnd::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::LogicalAnd::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
...@@ -45,6 +51,12 @@ op::v0::And::And(const Output<Node>& arg0, ...@@ -45,6 +51,12 @@ op::v0::And::And(const Output<Node>& arg0,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v0::And::visit_attributes(AttributeVisitor& visitor)
{
BinaryElementwiseLogical::visit_attributes(visitor);
return true;
}
shared_ptr<Node> op::v0::And::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::And::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -53,7 +53,7 @@ namespace ngraph ...@@ -53,7 +53,7 @@ namespace ngraph
AutoBroadcastSpec(AutoBroadcastType::NUMPY)); AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
}; };
} // namespace v0 } // namespace v0
...@@ -85,7 +85,7 @@ namespace ngraph ...@@ -85,7 +85,7 @@ namespace ngraph
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
}; };
} }
......
...@@ -44,7 +44,7 @@ namespace ngraph ...@@ -44,7 +44,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
/// \return The default value for Any. /// \return The default value for Any.
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
}; };
......
...@@ -28,6 +28,12 @@ op::ArgMax::ArgMax(const Output<Node>& arg, size_t axis, const element::Type& in ...@@ -28,6 +28,12 @@ op::ArgMax::ArgMax(const Output<Node>& arg, size_t axis, const element::Type& in
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::ArgMax::visit_attributes(AttributeVisitor& visitor)
{
IndexReduction::visit_attributes(visitor);
return true;
}
shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -41,7 +41,7 @@ namespace ngraph ...@@ -41,7 +41,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
}; };
} }
......
...@@ -28,6 +28,12 @@ op::ArgMin::ArgMin(const Output<Node>& arg, size_t axis, const element::Type& in ...@@ -28,6 +28,12 @@ op::ArgMin::ArgMin(const Output<Node>& arg, size_t axis, const element::Type& in
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::ArgMin::visit_attributes(AttributeVisitor& visitor)
{
IndexReduction::visit_attributes(visitor);
return true;
}
shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -42,7 +42,7 @@ namespace ngraph ...@@ -42,7 +42,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
}; };
} }
......
...@@ -45,7 +45,7 @@ namespace ngraph ...@@ -45,7 +45,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
...@@ -46,7 +46,7 @@ namespace ngraph ...@@ -46,7 +46,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
...@@ -78,6 +78,19 @@ op::v0::AvgPool::AvgPool(const Output<Node>& arg, ...@@ -78,6 +78,19 @@ op::v0::AvgPool::AvgPool(const Output<Node>& arg,
{ {
} }
bool op::v0::AvgPool::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("window_shape", m_window_shape);
visitor.on_attribute("window_movement_strides", m_window_movement_strides);
visitor.on_attribute("padding_below", m_padding_below);
visitor.on_attribute("padding_above", m_padding_above);
visitor.on_attribute("include_padding_in_avg_computation",
m_include_padding_in_avg_computation);
visitor.on_attribute("pad_type", m_pad_type);
visitor.on_attribute("ceil_mode", m_ceil_mode);
return true;
}
void op::v0::AvgPool::validate_and_infer_types() void op::v0::AvgPool::validate_and_infer_types()
{ {
if (0 == m_window_movement_strides.size()) if (0 == m_window_movement_strides.size())
...@@ -251,6 +264,18 @@ op::v0::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape, ...@@ -251,6 +264,18 @@ op::v0::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v0::AvgPoolBackprop::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("forward_arg_shape", m_forward_arg_shape);
visitor.on_attribute("window_shape", m_window_shape);
visitor.on_attribute("window_movement_strides", m_window_movement_strides);
visitor.on_attribute("padding_below", m_padding_below);
visitor.on_attribute("padding_above", m_padding_above);
visitor.on_attribute("include_padding_in_avg_computation",
m_include_padding_in_avg_computation);
return true;
}
void op::v0::AvgPoolBackprop::validate_and_infer_types() void op::v0::AvgPoolBackprop::validate_and_infer_types()
{ {
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
...@@ -420,6 +445,18 @@ op::v1::AvgPool::AvgPool(const Output<Node>& arg, ...@@ -420,6 +445,18 @@ op::v1::AvgPool::AvgPool(const Output<Node>& arg,
{ {
} }
bool op::v1::AvgPool::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("kernel", m_kernel);
visitor.on_attribute("strides", m_strides);
visitor.on_attribute("pads_begin", m_pads_begin);
visitor.on_attribute("pads_end", m_pads_end);
visitor.on_attribute("exclude_pad", m_exclude_pad);
visitor.on_attribute("auto_pad", m_auto_pad);
visitor.on_attribute("rounding_type", m_rounding_type);
return true;
}
void op::v1::AvgPool::validate_and_infer_types() void op::v1::AvgPool::validate_and_infer_types()
{ {
if (0 == m_strides.size()) if (0 == m_strides.size())
...@@ -575,6 +612,16 @@ op::v1::AvgPoolBackprop::AvgPoolBackprop(const Output<Node>& delta, ...@@ -575,6 +612,16 @@ op::v1::AvgPoolBackprop::AvgPoolBackprop(const Output<Node>& delta,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v1::AvgPoolBackprop::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("kernel", m_kernel);
visitor.on_attribute("strides", m_strides);
visitor.on_attribute("pads_begin", m_pads_begin);
visitor.on_attribute("pads_end", m_pads_end);
visitor.on_attribute("exclude_pad", m_exclude_pad);
return true;
}
const Shape op::v1::AvgPoolBackprop::get_forward_arg_shape() const const Shape op::v1::AvgPoolBackprop::get_forward_arg_shape() const
{ {
Shape shape; Shape shape;
......
...@@ -130,6 +130,8 @@ namespace ngraph ...@@ -130,6 +130,8 @@ namespace ngraph
/// `[n]` /// `[n]`
AvgPool(const Output<Node>& arg, const Shape& window_shape); AvgPool(const Output<Node>& arg, const Shape& window_shape);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -187,6 +189,7 @@ namespace ngraph ...@@ -187,6 +189,7 @@ namespace ngraph
bool include_padding_in_avg_computation); bool include_padding_in_avg_computation);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -281,6 +284,7 @@ namespace ngraph ...@@ -281,6 +284,7 @@ namespace ngraph
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -337,6 +341,7 @@ namespace ngraph ...@@ -337,6 +341,7 @@ namespace ngraph
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -46,6 +46,12 @@ op::BatchNormTraining::BatchNormTraining(double eps, ...@@ -46,6 +46,12 @@ op::BatchNormTraining::BatchNormTraining(double eps,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::BatchNormTraining::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("epsilon", m_epsilon);
return true;
}
void op::BatchNormTraining::validate_and_infer_types() void op::BatchNormTraining::validate_and_infer_types()
{ {
element::Type result_et; element::Type result_et;
...@@ -129,6 +135,12 @@ op::BatchNormInference::BatchNormInference(double eps, ...@@ -129,6 +135,12 @@ op::BatchNormInference::BatchNormInference(double eps,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::BatchNormInference::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("epsilon", m_epsilon);
return true;
}
void op::BatchNormInference::validate_and_infer_types() void op::BatchNormInference::validate_and_infer_types()
{ {
element::Type result_et; element::Type result_et;
...@@ -191,6 +203,12 @@ op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon, ...@@ -191,6 +203,12 @@ op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::BatchNormTrainingBackprop::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("epsilon", m_epsilon);
return true;
}
void op::BatchNormTrainingBackprop::validate_and_infer_types() void op::BatchNormTrainingBackprop::validate_and_infer_types()
{ {
PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)}; PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)};
......
...@@ -43,6 +43,8 @@ namespace ngraph ...@@ -43,6 +43,8 @@ namespace ngraph
const Output<Node>& beta, const Output<Node>& beta,
double epsilon); double epsilon);
bool visit_attributes(AttributeVisitor& visitor) override;
NGRAPH_DEPRECATED_DOC NGRAPH_DEPRECATED_DOC
/// In this version of BatchNorm: /// In this version of BatchNorm:
/// ///
...@@ -108,6 +110,8 @@ namespace ngraph ...@@ -108,6 +110,8 @@ namespace ngraph
const Output<Node>& variance, const Output<Node>& variance,
double epsilon); double epsilon);
bool visit_attributes(AttributeVisitor& visitor) override;
NGRAPH_DEPRECATED_DOC NGRAPH_DEPRECATED_DOC
/// In this version of BatchNorm: /// In this version of BatchNorm:
/// ///
...@@ -184,6 +188,7 @@ namespace ngraph ...@@ -184,6 +188,7 @@ namespace ngraph
const Output<Node>& delta); const Output<Node>& delta);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
double get_eps_value() const { return m_epsilon; } double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; } void set_eps_value(double epsilon) { m_epsilon = epsilon; }
......
...@@ -46,6 +46,12 @@ op::v1::Broadcast::Broadcast(const Output<Node>& arg, ...@@ -46,6 +46,12 @@ op::v1::Broadcast::Broadcast(const Output<Node>& arg,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v1::Broadcast::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("broadcast_spec", m_broadcast_spec);
return true;
}
std::pair<bool, AxisSet> op::v1::Broadcast::get_broadcast_axes() const std::pair<bool, AxisSet> op::v1::Broadcast::get_broadcast_axes() const
{ {
AxisSet broadcast_axes; AxisSet broadcast_axes;
...@@ -286,6 +292,13 @@ op::v0::Broadcast::Broadcast(const Output<Node>& arg, ...@@ -286,6 +292,13 @@ op::v0::Broadcast::Broadcast(const Output<Node>& arg,
{ {
} }
bool op::v0::Broadcast::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("shape", m_shape);
visitor.on_attribute("broadcast_axes", m_broadcast_axes);
return true;
}
void op::v0::Broadcast::validate_and_infer_types() void op::v0::Broadcast::validate_and_infer_types()
{ {
infer_shape(); infer_shape();
...@@ -355,6 +368,14 @@ op::v0::BroadcastLike::BroadcastLike(const Output<Node>& arg, ...@@ -355,6 +368,14 @@ op::v0::BroadcastLike::BroadcastLike(const Output<Node>& arg,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v0::BroadcastLike::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("shape", m_shape);
visitor.on_attribute("broadcast_axes", m_broadcast_axes);
visitor.on_attribute("initial_broadcast_axes", m_initial_broadcast_axes);
return true;
}
shared_ptr<Node> op::v0::BroadcastLike::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::BroadcastLike::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
......
...@@ -46,7 +46,7 @@ namespace ngraph ...@@ -46,7 +46,7 @@ namespace ngraph
Broadcast(const Output<Node>& arg, Broadcast(const Output<Node>& arg,
const Shape& shape, const Shape& shape,
const AxisSet& broadcast_axes); const AxisSet& broadcast_axes);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -94,7 +94,7 @@ namespace ngraph ...@@ -94,7 +94,7 @@ namespace ngraph
BroadcastLike(const Output<Node>& arg, BroadcastLike(const Output<Node>& arg,
const Output<Node>& like_arg, const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes); const AxisSet& initial_broadcast_axes);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -154,7 +154,7 @@ namespace ngraph ...@@ -154,7 +154,7 @@ namespace ngraph
const Output<Node>& target_shape, const Output<Node>& target_shape,
const AutoBroadcastSpec& broadcast_spec = const AutoBroadcastSpec& broadcast_spec =
AutoBroadcastSpec(AutoBroadcastType::NUMPY)); AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override;
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -21,13 +21,19 @@ using namespace ngraph; ...@@ -21,13 +21,19 @@ using namespace ngraph;
constexpr NodeTypeInfo op::BroadcastDistributed::type_info; constexpr NodeTypeInfo op::BroadcastDistributed::type_info;
op::BroadcastDistributed::BroadcastDistributed(const Output<Node>& arg, int root_id) op::BroadcastDistributed::BroadcastDistributed(const Output<Node>& arg, int64_t root_id)
: Op({arg}) : Op({arg})
, m_root_id(root_id) , m_root_id(root_id)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::BroadcastDistributed::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("root_id", m_root_id);
return true;
}
void op::BroadcastDistributed::validate_and_infer_types() void op::BroadcastDistributed::validate_and_infer_types()
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
...@@ -47,12 +53,12 @@ shared_ptr<Node> op::BroadcastDistributed::copy_with_new_args(const NodeVector& ...@@ -47,12 +53,12 @@ shared_ptr<Node> op::BroadcastDistributed::copy_with_new_args(const NodeVector&
return make_shared<BroadcastDistributed>(new_args.at(0), m_root_id); return make_shared<BroadcastDistributed>(new_args.at(0), m_root_id);
} }
int op::BroadcastDistributed::get_root_id() const int64_t op::BroadcastDistributed::get_root_id() const
{ {
return m_root_id; return m_root_id;
} }
void op::BroadcastDistributed::set_root_id(int root_id) void op::BroadcastDistributed::set_root_id(int64_t root_id)
{ {
m_root_id = root_id; m_root_id = root_id;
} }
...@@ -31,17 +31,17 @@ namespace ngraph ...@@ -31,17 +31,17 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"BroadcastDistributed", 0}; static constexpr NodeTypeInfo type_info{"BroadcastDistributed", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
BroadcastDistributed() = default; BroadcastDistributed() = default;
BroadcastDistributed(const Output<Node>& arg, int root_id = 0); BroadcastDistributed(const Output<Node>& arg, int64_t root_id = 0);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
int get_root_id() const; int64_t get_root_id() const;
void set_root_id(int root_id); void set_root_id(int64_t root_id);
private: private:
int m_root_id; int64_t m_root_id;
}; };
} }
} }
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Ceiling(const Output<Node>& arg); Ceiling(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
}; };
......
...@@ -36,6 +36,12 @@ op::Concat::Concat(const NodeVector& args, int64_t axis) ...@@ -36,6 +36,12 @@ op::Concat::Concat(const NodeVector& args, int64_t axis)
{ {
} }
bool op::Concat::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("axis", m_axis);
return true;
}
void op::Concat::validate_and_infer_types() void op::Concat::validate_and_infer_types()
{ {
NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required."); NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required.");
......
...@@ -44,7 +44,7 @@ namespace ngraph ...@@ -44,7 +44,7 @@ namespace ngraph
/// \param args The nodes producing the input tensors. /// \param args The nodes producing the input tensors.
/// \param axis The axis along which to concatenate the input tensors. /// \param axis The axis along which to concatenate the input tensors.
Concat(const NodeVector& args, int64_t axis); Concat(const NodeVector& args, int64_t axis);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <sstream> #include <sstream>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/parameter.hpp" #include "ngraph/op/parameter.hpp"
using namespace std; using namespace std;
...@@ -34,6 +35,14 @@ op::Parameter::Parameter(const element::Type& element_type, ...@@ -34,6 +35,14 @@ op::Parameter::Parameter(const element::Type& element_type,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::Parameter::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("cacheable", m_cacheable);
visitor.on_attribute("shape", m_partial_shape);
visitor.on_attribute("element_type", m_element_type);
return true;
}
void op::Parameter::validate_and_infer_types() void op::Parameter::validate_and_infer_types()
{ {
Op::validate_and_infer_types(); Op::validate_and_infer_types();
......
...@@ -49,6 +49,8 @@ namespace ngraph ...@@ -49,6 +49,8 @@ namespace ngraph
const PartialShape& pshape, const PartialShape& pshape,
const bool cacheable = false); const bool cacheable = false);
bool visit_attributes(AttributeVisitor& visitor) override;
bool is_parameter() const override { return true; } bool is_parameter() const override { return true; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -15,17 +15,112 @@ ...@@ -15,17 +15,112 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
#include "ngraph/enum_names.hpp"
using namespace ngraph; using namespace ngraph;
std::ostream& op::operator<<(std::ostream& s, const op::AutoBroadcastType& type) namespace ngraph
{ {
switch (type) template <>
EnumNames<op::PadMode>& EnumNames<op::PadMode>::get()
{ {
case op::AutoBroadcastType::NONE: s << "NONE"; break; static auto enum_names = EnumNames<op::PadMode>("op::PadMode",
case op::AutoBroadcastType::NUMPY: s << "NUMPY"; break; {{"CONSTANT", op::PadMode::CONSTANT},
case op::AutoBroadcastType::PDPD: s << "PDPD"; break; {"EDGE", op::PadMode::EDGE},
default: s << "Undefined Type"; {"REFLECT", op::PadMode::REFLECT},
{"SYMMETRIC", op::PadMode::SYMMETRIC}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::PadMode>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::PadMode& type)
{
return s << as_string(type);
}
template <>
EnumNames<op::PadType>& EnumNames<op::PadType>::get()
{
static auto enum_names = EnumNames<op::PadType>("op::PadType",
{{"EXPLICIT", op::PadType::EXPLICIT},
{"SAME_LOWER", op::PadType::SAME_LOWER},
{"SAME_UPPER", op::PadType::SAME_UPPER},
{"VALID", op::PadType::VALID}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::PadType>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::PadType& type)
{
return s << as_string(type);
}
template <>
EnumNames<op::RoundingType>& EnumNames<op::RoundingType>::get()
{
static auto enum_names = EnumNames<op::RoundingType>(
"op::RoundingType",
{{"FLOOR", op::RoundingType::FLOOR}, {"CEIL", op::RoundingType::CEIL}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::RoundingType>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::RoundingType& type)
{
return s << as_string(type);
}
template <>
EnumNames<op::AutoBroadcastType>& EnumNames<op::AutoBroadcastType>::get()
{
static auto enum_names =
EnumNames<op::AutoBroadcastType>("op::AutoBroadcastType",
{{"NONE", op::AutoBroadcastType::NONE},
{"NUMPY", op::AutoBroadcastType::NUMPY},
{"PDPD", op::AutoBroadcastType::PDPD}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::AutoBroadcastType>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::AutoBroadcastType& type)
{
return s << as_string(type);
}
template <>
EnumNames<op::EpsMode>& EnumNames<op::EpsMode>::get()
{
static auto enum_names = EnumNames<op::EpsMode>(
"op::EpsMode", {{"ADD", op::EpsMode::ADD}, {"MAX", op::EpsMode::MAX}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::EpsMode>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::EpsMode& type)
{
return s << as_string(type);
}
template <>
EnumNames<op::TopKSortType>& EnumNames<op::TopKSortType>::get()
{
static auto enum_names =
EnumNames<op::TopKSortType>("op::TopKSortType",
{{"NONE", op::TopKSortType::NONE},
{"SORT_INDICES", op::TopKSortType::SORT_INDICES},
{"SORT_VALUES", op::TopKSortType::SORT_VALUES}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::TopKSortType>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::TopKSortType& type)
{
return s << as_string(type);
} }
return s;
} }
...@@ -19,6 +19,9 @@ ...@@ -19,6 +19,9 @@
#include <cstddef> #include <cstddef>
#include <ostream> #include <ostream>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/type.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
...@@ -32,6 +35,25 @@ namespace ngraph ...@@ -32,6 +35,25 @@ namespace ngraph
SYMMETRIC SYMMETRIC
}; };
std::ostream& operator<<(std::ostream& s, const PadMode& type);
}
template <>
class AttributeAdapter<op::PadMode> : public EnumAttributeAdapterBase<op::PadMode>
{
public:
AttributeAdapter(op::PadMode& value)
: EnumAttributeAdapterBase<op::PadMode>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::PadMode>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
/// \brief Padding Type used for `Convolution` and `Pooling` /// \brief Padding Type used for `Convolution` and `Pooling`
/// ///
/// Follows ONNX padding type definitions /// Follows ONNX padding type definitions
...@@ -54,6 +76,25 @@ namespace ngraph ...@@ -54,6 +76,25 @@ namespace ngraph
NOTSET = EXPLICIT, NOTSET = EXPLICIT,
}; };
std::ostream& operator<<(std::ostream& s, const PadType& type);
}
template <>
class AttributeAdapter<op::PadType> : public EnumAttributeAdapterBase<op::PadType>
{
public:
AttributeAdapter(op::PadType& value)
: EnumAttributeAdapterBase<op::PadType>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::PadType>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
/// \brief Rounding Type used for `Pooling` operators. /// \brief Rounding Type used for `Pooling` operators.
enum class RoundingType enum class RoundingType
{ {
...@@ -61,6 +102,25 @@ namespace ngraph ...@@ -61,6 +102,25 @@ namespace ngraph
CEIL = 1, CEIL = 1,
}; };
std::ostream& operator<<(std::ostream& s, const RoundingType& type);
}
template <>
class AttributeAdapter<op::RoundingType> : public EnumAttributeAdapterBase<op::RoundingType>
{
public:
AttributeAdapter(op::RoundingType& value)
: EnumAttributeAdapterBase<op::RoundingType>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::RoundingType>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
/// \brief Specifies the algorithm to use for implicit broadcasting of a tensor /// \brief Specifies the algorithm to use for implicit broadcasting of a tensor
/// to align with another tensor /// to align with another tensor
/// ///
...@@ -107,6 +167,26 @@ namespace ngraph ...@@ -107,6 +167,26 @@ namespace ngraph
PDPD PDPD
}; };
std::ostream& operator<<(std::ostream& s, const AutoBroadcastType& type);
}
template <>
class AttributeAdapter<op::AutoBroadcastType>
: public EnumAttributeAdapterBase<op::AutoBroadcastType>
{
public:
AttributeAdapter(op::AutoBroadcastType& value)
: EnumAttributeAdapterBase<op::AutoBroadcastType>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::AutoBroadcastType>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
/// \brief Specifies how eps is combined with L2 value /// \brief Specifies how eps is combined with L2 value
enum class EpsMode enum class EpsMode
{ {
...@@ -116,6 +196,25 @@ namespace ngraph ...@@ -116,6 +196,25 @@ namespace ngraph
MAX MAX
}; };
std::ostream& operator<<(std::ostream& s, const EpsMode& type);
}
template <>
class AttributeAdapter<op::EpsMode> : public EnumAttributeAdapterBase<op::EpsMode>
{
public:
AttributeAdapter(op::EpsMode& value)
: EnumAttributeAdapterBase<op::EpsMode>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::EpsMode>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
enum class TopKSortType enum class TopKSortType
{ {
// Returned values are not sorted // Returned values are not sorted
...@@ -125,7 +224,25 @@ namespace ngraph ...@@ -125,7 +224,25 @@ namespace ngraph
// Sort result based on element values // Sort result based on element values
SORT_VALUES, SORT_VALUES,
}; };
std::ostream& operator<<(std::ostream& s, const TopKSortType& type);
}
template <>
class AttributeAdapter<op::TopKSortType> : public EnumAttributeAdapterBase<op::TopKSortType>
{
public:
AttributeAdapter(op::TopKSortType& value)
: EnumAttributeAdapterBase<op::TopKSortType>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::TopKSortType>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
/// \brief Implicit broadcast specification /// \brief Implicit broadcast specification
struct AutoBroadcastSpec struct AutoBroadcastSpec
{ {
...@@ -153,7 +270,5 @@ namespace ngraph ...@@ -153,7 +270,5 @@ namespace ngraph
return a.m_type == m_type && a.m_axis == m_axis; return a.m_type == m_type && a.m_axis == m_axis;
} }
}; };
std::ostream& operator<<(std::ostream& s, const AutoBroadcastType& type);
} }
} }
...@@ -54,3 +54,9 @@ void op::util::BinaryElementwiseArithmetic::validate_and_infer_types() ...@@ -54,3 +54,9 @@ void op::util::BinaryElementwiseArithmetic::validate_and_infer_types()
{ {
validate_and_infer_elementwise_arithmetic(m_autob); validate_and_infer_elementwise_arithmetic(m_autob);
} }
bool op::util::BinaryElementwiseArithmetic::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("autob", m_autob);
return true;
}
...@@ -86,6 +86,8 @@ namespace ngraph ...@@ -86,6 +86,8 @@ namespace ngraph
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; } void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool is_binary_elementwise_arithmetic() const override { return true; } bool is_binary_elementwise_arithmetic() const override { return true; }
bool supports_auto_broadcast() const override { return true; } bool supports_auto_broadcast() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private: private:
AutoBroadcastSpec m_autob; AutoBroadcastSpec m_autob;
}; };
......
...@@ -55,3 +55,9 @@ void op::util::BinaryElementwiseComparison::validate_and_infer_types() ...@@ -55,3 +55,9 @@ void op::util::BinaryElementwiseComparison::validate_and_infer_types()
set_output_type(0, element::boolean, args_pshape); set_output_type(0, element::boolean, args_pshape);
} }
bool op::util::BinaryElementwiseComparison::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("autob", m_autob);
return true;
}
...@@ -92,6 +92,8 @@ namespace ngraph ...@@ -92,6 +92,8 @@ namespace ngraph
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; } void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool supports_auto_broadcast() const override { return true; } bool supports_auto_broadcast() const override { return true; }
bool is_binary_elementwise_comparison() const override { return true; } bool is_binary_elementwise_comparison() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private: private:
AutoBroadcastSpec m_autob; AutoBroadcastSpec m_autob;
}; };
......
...@@ -52,3 +52,9 @@ void op::util::BinaryElementwiseLogical::validate_and_infer_types() ...@@ -52,3 +52,9 @@ void op::util::BinaryElementwiseLogical::validate_and_infer_types()
{ {
validate_and_infer_elementwise_logical(m_autob); validate_and_infer_elementwise_logical(m_autob);
} }
bool op::util::BinaryElementwiseLogical::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("autob", m_autob);
return true;
}
...@@ -88,6 +88,8 @@ namespace ngraph ...@@ -88,6 +88,8 @@ namespace ngraph
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; } void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool supports_auto_broadcast() const override { return true; } bool supports_auto_broadcast() const override { return true; }
bool is_binary_elementwise_logical() const override { return true; } bool is_binary_elementwise_logical() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private: private:
AutoBroadcastSpec m_autob; AutoBroadcastSpec m_autob;
}; };
......
...@@ -26,7 +26,7 @@ op::util::IndexReduction::IndexReduction() ...@@ -26,7 +26,7 @@ op::util::IndexReduction::IndexReduction()
} }
op::util::IndexReduction::IndexReduction(const Output<Node>& arg, op::util::IndexReduction::IndexReduction(const Output<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type) const element::Type& index_element_type)
: Op({arg}) : Op({arg})
{ {
...@@ -35,7 +35,7 @@ op::util::IndexReduction::IndexReduction(const Output<Node>& arg, ...@@ -35,7 +35,7 @@ op::util::IndexReduction::IndexReduction(const Output<Node>& arg,
} }
op::util::IndexReduction::IndexReduction(const std::shared_ptr<Node>& arg, op::util::IndexReduction::IndexReduction(const std::shared_ptr<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type) const element::Type& index_element_type)
: Op(check_single_output_args({arg})) : Op(check_single_output_args({arg}))
{ {
...@@ -45,7 +45,7 @@ op::util::IndexReduction::IndexReduction(const std::shared_ptr<Node>& arg, ...@@ -45,7 +45,7 @@ op::util::IndexReduction::IndexReduction(const std::shared_ptr<Node>& arg,
op::util::IndexReduction::IndexReduction(const std::string& node_type, op::util::IndexReduction::IndexReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type) const element::Type& index_element_type)
: Op(node_type, check_single_output_args({arg})) : Op(node_type, check_single_output_args({arg}))
{ {
...@@ -53,11 +53,11 @@ op::util::IndexReduction::IndexReduction(const std::string& node_type, ...@@ -53,11 +53,11 @@ op::util::IndexReduction::IndexReduction(const std::string& node_type,
set_index_element_type(index_element_type); set_index_element_type(index_element_type);
} }
size_t op::util::IndexReduction::get_reduction_axis() const uint64_t op::util::IndexReduction::get_reduction_axis() const
{ {
return m_axis; return m_axis;
} }
void op::util::IndexReduction::set_reduction_axis(size_t value) void op::util::IndexReduction::set_reduction_axis(uint64_t value)
{ {
m_axis = value; m_axis = value;
} }
...@@ -125,3 +125,10 @@ void op::util::IndexReduction::generate_adjoints(autodiff::Adjoints& /* adjoints ...@@ -125,3 +125,10 @@ void op::util::IndexReduction::generate_adjoints(autodiff::Adjoints& /* adjoints
{ {
throw ngraph_error("Forward-propagation-only operation"); throw ngraph_error("Forward-propagation-only operation");
} }
bool op::util::IndexReduction::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("axis", m_axis);
visitor.on_attribute("index_element_type", m_index_element_type);
return true;
}
...@@ -35,27 +35,28 @@ namespace ngraph ...@@ -35,27 +35,28 @@ namespace ngraph
IndexReduction(); IndexReduction();
IndexReduction(const Output<Node>& arg, IndexReduction(const Output<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type); const element::Type& index_element_type);
IndexReduction(const std::shared_ptr<Node>& arg, IndexReduction(const std::shared_ptr<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type); const element::Type& index_element_type);
IndexReduction(const std::string& node_type, IndexReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type); const element::Type& index_element_type);
public: public:
size_t get_reduction_axis() const; uint64_t get_reduction_axis() const;
void set_reduction_axis(size_t value); void set_reduction_axis(uint64_t value);
element::Type get_index_element_type() const; element::Type get_index_element_type() const;
void set_index_element_type(const element::Type& index_element_type); void set_index_element_type(const element::Type& index_element_type);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
protected: protected:
size_t m_axis{0}; uint64_t m_axis{0};
element::Type m_index_element_type; element::Type m_index_element_type;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#pragma once #pragma once
#include <cstdint>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
......
...@@ -46,6 +46,7 @@ set(SRC ...@@ -46,6 +46,7 @@ set(SRC
aligned_buffer.cpp aligned_buffer.cpp
all_close_f.cpp all_close_f.cpp
assertion.cpp assertion.cpp
attributes.cpp
bfloat16.cpp bfloat16.cpp
build_graph.cpp build_graph.cpp
builder_autobroadcast.cpp builder_autobroadcast.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
enum class TuringModel
{
XL400,
XL1200
};
namespace ngraph
{
template <>
EnumNames<TuringModel>& EnumNames<TuringModel>::get()
{
static auto enum_names = EnumNames<TuringModel>(
"TuringModel", {{"XL400", TuringModel::XL400}, {"XL1200", TuringModel::XL1200}});
return enum_names;
}
template <>
class AttributeAdapter<TuringModel> : public EnumAttributeAdapterBase<TuringModel>
{
public:
AttributeAdapter(TuringModel& value)
: EnumAttributeAdapterBase<TuringModel>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<TuringModel>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
constexpr DiscreteTypeInfo AttributeAdapter<TuringModel>::type_info;
}
// Given a Turing machine program and data, return scalar 1 if the program would
// complete, 1 if it would not.
class Oracle : public op::Op
{
public:
Oracle(const Output<Node>& program,
const Output<Node>& data,
TuringModel turing_model,
uint64_t model_version,
uint8_t rev,
const string& serial_number,
bool enable_turbo,
const std::vector<uint64_t>& hyper_parameters,
const std::vector<int64_t>& ultra_parameters)
: Op({program, data})
, m_turing_model(turing_model)
, m_model_version(model_version)
, m_rev(rev)
, m_serial_number(serial_number)
, m_enable_turbo(enable_turbo)
, m_hyper_parameters(hyper_parameters)
, m_ultra_parameters(ultra_parameters)
{
}
static constexpr NodeTypeInfo type_info{"Oracle", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Oracle() = default;
TuringModel get_turing_model() const { return m_turing_model; }
uint64_t get_model_version() const { return m_model_version; }
const string& get_serial_number() const { return m_serial_number; }
bool get_enable_turbo() const { return m_enable_turbo; }
const vector<uint64_t>& get_hyper_parameters() const { return m_hyper_parameters; }
const vector<int64_t>& get_ultra_parameters() const { return m_ultra_parameters; }
shared_ptr<Node> copy_with_new_args(const NodeVector& args) const override
{
return make_shared<Oracle>(args[0],
args[1],
m_turing_model,
m_model_version,
m_rev,
m_serial_number,
m_enable_turbo,
m_hyper_parameters,
m_ultra_parameters);
}
void validate_and_infer_types() override { set_output_type(0, element::i64, {}); }
bool visit_attributes(AttributeVisitor& visitor) override
{
visitor.on_attribute("turing_model", m_turing_model);
visitor.on_attribute("model_version", m_model_version);
visitor.on_attribute("rev", m_rev);
visitor.on_attribute("serial_number", m_serial_number);
visitor.on_attribute("enable_turbo", m_enable_turbo);
visitor.on_attribute("hyper_parameters", m_hyper_parameters);
visitor.on_attribute("ultra_parameters", m_ultra_parameters);
return true;
}
protected:
TuringModel m_turing_model;
uint64_t m_model_version;
int8_t m_rev;
string m_serial_number;
bool m_enable_turbo;
vector<uint64_t> m_hyper_parameters;
vector<int64_t> m_ultra_parameters;
};
constexpr NodeTypeInfo Oracle::type_info;
class NodeSaver : public AttributeVisitor
{
public:
NodeSaver(shared_ptr<Node> node)
: m_node_type_info(node->get_type_info())
{
node->visit_attributes(*this);
}
const NodeTypeInfo& get_node_type_info() { return m_node_type_info; }
string& get_string(const string& name) { return m_strings.at(name); }
bool get_bool(const string& name) { return m_bools.at(name); }
double get_double(const string& name) { return m_doubles.at(name); }
int64_t get_signed(const string& name) { return m_signeds.at(name); }
uint64_t get_unsigned(const string& name) { return m_unsigneds.at(name); }
vector<int64_t>& get_signed_vector(const string& name) { return m_signed_vectors.at(name); }
void set_string(const string& name, const string& value) { m_strings[name] = value; }
void set_bool(const string& name, bool value) { m_bools[name] = value; }
void set_double(const string& name, double value) { m_doubles[name] = value; }
void set_signed(const string& name, int64_t value) { m_signeds[name] = value; }
void set_unsigned(const string& name, uint64_t value) { m_unsigneds[name] = value; }
void set_signed_vector(const string& name, const vector<int64_t>& value)
{
m_signed_vectors[name] = value;
}
void on_attribute(const string& name, string& value) override { set_string(name, value); };
void on_attribute(const string& name, bool& value) override { set_bool(name, value); }
void on_adapter(const string& name, ValueAccessor<void>& adapter) override
{
NGRAPH_CHECK(false, "name cannot be marshalled");
}
// The remaining adapter methods fall back on the void adapter if not implemented
void on_adapter(const string& name, ValueAccessor<string>& adapter) override
{
set_string(name, adapter.get());
};
void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
{
set_signed_vector(name, adapter.get());
}
void on_adapter(const string& name, ValueAccessor<int64_t>& adapter) override
{
set_signed(name, adapter.get());
}
void on_adapter(const string& name, ValueAccessor<double>& adapter) override
{
set_double(name, adapter.get());
}
protected:
NodeTypeInfo m_node_type_info;
map<string, string> m_strings;
map<string, bool> m_bools;
map<string, double> m_doubles;
map<string, int64_t> m_signeds;
map<string, uint64_t> m_unsigneds;
map<string, vector<int64_t>> m_signed_vectors;
};
class NodeBuilder : public AttributeVisitor
{
public:
NodeBuilder(const shared_ptr<Node>& node)
: m_values(node)
{
}
shared_ptr<Node> create()
{
shared_ptr<Node> node(FactoryRegistry<Node>::get().create(m_values.get_node_type_info()));
node->visit_attributes(*this);
node->validate_and_infer_types();
return node;
}
void on_attribute(const string& name, string& value) override
{
value = m_values.get_string(name);
};
void on_attribute(const string& name, bool& value) override { value = m_values.get_bool(name); }
void on_adapter(const string& name, ValueAccessor<void>& adapter) override
{
NGRAPH_CHECK(false, "name cannot be marshalled");
}
// The remaining adapter methods fall back on the void adapter if not implemented
void on_adapter(const string& name, ValueAccessor<string>& adapter) override
{
adapter.set(m_values.get_string(name));
};
void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
{
adapter.set(m_values.get_signed_vector(name));
}
void on_adapter(const string& name, ValueAccessor<int64_t>& adapter) override
{
adapter.set(m_values.get_signed(name));
}
void on_adapter(const string& name, ValueAccessor<double>& adapter) override
{
adapter.set(m_values.get_double(name));
}
protected:
NodeSaver m_values;
};
TEST(attributes, user_op)
{
FactoryRegistry<Node>::get().register_factory<Oracle>();
auto program = make_shared<op::Parameter>(element::i32, Shape{200});
auto data = make_shared<op::Parameter>(element::i32, Shape{200});
auto oracle = make_shared<Oracle>(program,
data,
TuringModel::XL1200,
2,
4,
"12AU7",
true,
vector<uint64_t>{1, 2, 4, 8},
vector<int64_t>{-1, -2, -4, -8});
NodeBuilder builder(oracle);
auto g_oracle = as_type_ptr<Oracle>(builder.create());
EXPECT_EQ(g_oracle->get_turing_model(), oracle->get_turing_model());
EXPECT_EQ(g_oracle->get_model_version(), oracle->get_model_version());
EXPECT_EQ(g_oracle->get_serial_number(), oracle->get_serial_number());
EXPECT_EQ(g_oracle->get_enable_turbo(), oracle->get_enable_turbo());
EXPECT_EQ(g_oracle->get_hyper_parameters(), oracle->get_hyper_parameters());
EXPECT_EQ(g_oracle->get_ultra_parameters(), oracle->get_ultra_parameters());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment