Unverified Commit f73cfcf0 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Allow for opset extensions (#4391)

* Allow for opset extensions

* Remove raw pointer from OpSet (#4395)

* style
Co-authored-by: 's avatarIlya Churaev <ilyachur@gmail.com>
parent ed30c615
...@@ -77,7 +77,7 @@ namespace ngraph ...@@ -77,7 +77,7 @@ namespace ngraph
} }
/// \brief Create an instance for type_info /// \brief Create an instance for type_info
BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info) BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info) const
{ {
std::lock_guard<std::mutex> guard(get_registry_mutex()); std::lock_guard<std::mutex> guard(get_registry_mutex());
auto it = m_factory_map.find(type_info); auto it = m_factory_map.find(type_info);
...@@ -86,7 +86,7 @@ namespace ngraph ...@@ -86,7 +86,7 @@ namespace ngraph
/// \brief Create an instance using factory for DERIVED_TYPE /// \brief Create an instance using factory for DERIVED_TYPE
template <typename DERIVED_TYPE> template <typename DERIVED_TYPE>
BASE_TYPE* create() BASE_TYPE* create() const
{ {
return create(DERIVED_TYPE::type_info); return create(DERIVED_TYPE::type_info);
} }
......
...@@ -28,7 +28,7 @@ ngraph::Node* ngraph::OpSet::create(const std::string& name) const ...@@ -28,7 +28,7 @@ ngraph::Node* ngraph::OpSet::create(const std::string& name) const
auto type_info_it = m_name_type_info_map.find(name); auto type_info_it = m_name_type_info_map.find(name);
return type_info_it == m_name_type_info_map.end() return type_info_it == m_name_type_info_map.end()
? nullptr ? nullptr
: FactoryRegistry<Node>::get().create(type_info_it->second); : m_factory_registry.create(type_info_it->second);
} }
const ngraph::OpSet& ngraph::get_opset0() const ngraph::OpSet& ngraph::get_opset0()
......
...@@ -32,6 +32,7 @@ namespace ngraph ...@@ -32,6 +32,7 @@ namespace ngraph
static std::mutex& get_mutex(); static std::mutex& get_mutex();
public: public:
OpSet() {}
std::set<NodeTypeInfo>::size_type size() const std::set<NodeTypeInfo>::size_type size() const
{ {
std::lock_guard<std::mutex> guard(get_mutex()); std::lock_guard<std::mutex> guard(get_mutex());
...@@ -45,7 +46,7 @@ namespace ngraph ...@@ -45,7 +46,7 @@ namespace ngraph
std::lock_guard<std::mutex> guard(get_mutex()); std::lock_guard<std::mutex> guard(get_mutex());
m_op_types.insert(type_info); m_op_types.insert(type_info);
m_name_type_info_map[name] = type_info; m_name_type_info_map[name] = type_info;
ngraph::FactoryRegistry<Node>::get().register_factory(type_info, factory); m_factory_registry.register_factory(type_info, factory);
} }
/// \brief Insert OP_TYPE into the opset with a special name and the default factory /// \brief Insert OP_TYPE into the opset with a special name and the default factory
...@@ -94,7 +95,9 @@ namespace ngraph ...@@ -94,7 +95,9 @@ namespace ngraph
return m_op_types.find(node->get_type_info()) != m_op_types.end(); return m_op_types.find(node->get_type_info()) != m_op_types.end();
} }
ngraph::FactoryRegistry<ngraph::Node>& get_factory_registry() { return m_factory_registry; }
protected: protected:
ngraph::FactoryRegistry<ngraph::Node> m_factory_registry;
std::set<NodeTypeInfo> m_op_types; std::set<NodeTypeInfo> m_op_types;
std::map<std::string, NodeTypeInfo> m_name_type_info_map; std::map<std::string, NodeTypeInfo> m_name_type_info_map;
}; };
...@@ -102,4 +105,4 @@ namespace ngraph ...@@ -102,4 +105,4 @@ namespace ngraph
const NGRAPH_API OpSet& get_opset0(); const NGRAPH_API OpSet& get_opset0();
const NGRAPH_API OpSet& get_opset1(); const NGRAPH_API OpSet& get_opset1();
const NGRAPH_API OpSet& get_opset2(); const NGRAPH_API OpSet& get_opset2();
} }
\ No newline at end of file
...@@ -188,6 +188,8 @@ TEST(opset, new_op) ...@@ -188,6 +188,8 @@ TEST(opset, new_op)
// Fred should be in the copy // Fred should be in the copy
fred = shared_ptr<Node>(opset1_copy.create("Fred")); fred = shared_ptr<Node>(opset1_copy.create("Fred"));
EXPECT_TRUE(fred); EXPECT_TRUE(fred);
// Fred should not be in the registry
ASSERT_FALSE(FactoryRegistry<Node>::get().has_factory<NewOp>());
} }
TEST(opset, dump) TEST(opset, dump)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment