Unverified Commit f73cfcf0 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Allow for opset extensions (#4391)

* Allow for opset extensions

* Remove raw pointer from OpSet (#4395)

* style
Co-authored-by: 's avatarIlya Churaev <ilyachur@gmail.com>
parent ed30c615
......@@ -77,7 +77,7 @@ namespace ngraph
}
/// \brief Create an instance for type_info
BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info)
BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info) const
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
auto it = m_factory_map.find(type_info);
......@@ -86,7 +86,7 @@ namespace ngraph
/// \brief Create an instance using factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
BASE_TYPE* create()
BASE_TYPE* create() const
{
return create(DERIVED_TYPE::type_info);
}
......
......@@ -28,7 +28,7 @@ ngraph::Node* ngraph::OpSet::create(const std::string& name) const
auto type_info_it = m_name_type_info_map.find(name);
return type_info_it == m_name_type_info_map.end()
? nullptr
: FactoryRegistry<Node>::get().create(type_info_it->second);
: m_factory_registry.create(type_info_it->second);
}
const ngraph::OpSet& ngraph::get_opset0()
......
......@@ -32,6 +32,7 @@ namespace ngraph
static std::mutex& get_mutex();
public:
OpSet() {}
std::set<NodeTypeInfo>::size_type size() const
{
std::lock_guard<std::mutex> guard(get_mutex());
......@@ -45,7 +46,7 @@ namespace ngraph
std::lock_guard<std::mutex> guard(get_mutex());
m_op_types.insert(type_info);
m_name_type_info_map[name] = type_info;
ngraph::FactoryRegistry<Node>::get().register_factory(type_info, factory);
m_factory_registry.register_factory(type_info, factory);
}
/// \brief Insert OP_TYPE into the opset with a special name and the default factory
......@@ -94,7 +95,9 @@ namespace ngraph
return m_op_types.find(node->get_type_info()) != m_op_types.end();
}
ngraph::FactoryRegistry<ngraph::Node>& get_factory_registry() { return m_factory_registry; }
protected:
ngraph::FactoryRegistry<ngraph::Node> m_factory_registry;
std::set<NodeTypeInfo> m_op_types;
std::map<std::string, NodeTypeInfo> m_name_type_info_map;
};
......
......@@ -188,6 +188,8 @@ TEST(opset, new_op)
// Fred should be in the copy
fred = shared_ptr<Node>(opset1_copy.create("Fred"));
EXPECT_TRUE(fred);
// Fred should not be in the registry
ASSERT_FALSE(FactoryRegistry<Node>::get().has_factory<NewOp>());
}
TEST(opset, dump)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment