Unverified Commit 8adf78fe authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Runtime-extensible opsets (#3991)

* Fix some opset bugs
Typo in opset0
Include ops.hpp rather than ngraph.hpp

* Opset op insertion

* Make opsets extendable, able to create instances

* Update like replacement

* Review comments
parent 83dcc092
...@@ -19,42 +19,82 @@ ...@@ -19,42 +19,82 @@
#include <functional> #include <functional>
#include <map> #include <map>
#include <mutex> #include <mutex>
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
namespace ngraph namespace ngraph
{ {
NGRAPH_API std::mutex& get_registry_mutex(); NGRAPH_API std::mutex& get_registry_mutex();
template <typename T> /// \brief Registry of factories that can construct objects derived from BASE_TYPE
template <typename BASE_TYPE>
class FactoryRegistry class FactoryRegistry
{ {
public: public:
using Factory = std::function<T*()>; using Factory = std::function<BASE_TYPE*()>;
using FactoryMap = std::map<decltype(T::type_info), Factory>; using FactoryMap = std::map<decltype(BASE_TYPE::type_info), Factory>;
template <typename U> // \brief Get the default factory for DERIVED_TYPE. Specialize as needed.
void register_factory() template <typename DERIVED_TYPE>
static Factory get_default_factory()
{
return []() { return new DERIVED_TYPE(); };
}
/// \brief Register a custom factory for type_info
void register_factory(const decltype(BASE_TYPE::type_info) & type_info, Factory factory)
{ {
std::lock_guard<std::mutex> guard(get_registry_mutex()); std::lock_guard<std::mutex> guard(get_registry_mutex());
m_factory_map[U::type_info] = []() { return new U(); }; m_factory_map[type_info] = factory;
} }
bool has_factory(const decltype(T::type_info) & info) /// \brief Register a custom factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
void register_factory(Factory factory)
{
register_factory(DERIVED_TYPE::type_info, factory);
}
/// \brief Register the defualt constructor factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
void register_factory()
{
register_factory<DERIVED_TYPE>(get_default_factory<DERIVED_TYPE>());
}
/// \brief Check to see if a factory is registered
bool has_factory(const decltype(BASE_TYPE::type_info) & info)
{ {
std::lock_guard<std::mutex> guard(get_registry_mutex()); std::lock_guard<std::mutex> guard(get_registry_mutex());
return m_factory_map.find(info) != m_factory_map.end(); return m_factory_map.find(info) != m_factory_map.end();
} }
T* create(const decltype(T::type_info) & info) /// \brief Check to see if DERIVED_TYPE has a registered factory
template <typename DERIVED_TYPE>
bool has_factory()
{
return has_factory(DERIVED_TYPE::type_info);
}
/// \brief Create an instance for type_info
BASE_TYPE* create(const decltype(BASE_TYPE::type_info) & type_info)
{ {
std::lock_guard<std::mutex> guard(get_registry_mutex()); std::lock_guard<std::mutex> guard(get_registry_mutex());
auto it = m_factory_map.find(info); auto it = m_factory_map.find(type_info);
return it == m_factory_map.end() ? nullptr : it->second(); return it == m_factory_map.end() ? nullptr : it->second();
} }
static FactoryRegistry<T>& get();
/// \brief Create an instance using factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
BASE_TYPE* create()
{
return create(DERIVED_TYPE::type_info);
}
/// \brief Get the factory for BASE_TYPE
static FactoryRegistry<BASE_TYPE>& get();
protected: protected:
// Need a Compare on type_info
FactoryMap m_factory_map; FactoryMap m_factory_map;
}; };
} }
...@@ -337,7 +337,7 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const ...@@ -337,7 +337,7 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
return rc; return rc;
} }
constexpr NodeTypeInfo op::ScalarConstantLikeBase::type_info; constexpr NodeTypeInfo op::ScalarConstantLike::type_info;
shared_ptr<op::Constant> op::ScalarConstantLikeBase::as_constant() const shared_ptr<op::Constant> op::ScalarConstantLikeBase::as_constant() const
{ {
......
...@@ -359,8 +359,6 @@ namespace ngraph ...@@ -359,8 +359,6 @@ namespace ngraph
class NGRAPH_API ScalarConstantLikeBase : public Constant class NGRAPH_API ScalarConstantLikeBase : public Constant
{ {
public: public:
static constexpr NodeTypeInfo type_info{"ScalarConstantLikeBase", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
std::shared_ptr<op::Constant> as_constant() const; std::shared_ptr<op::Constant> as_constant() const;
ScalarConstantLikeBase() = default; ScalarConstantLikeBase() = default;
...@@ -375,6 +373,8 @@ namespace ngraph ...@@ -375,6 +373,8 @@ namespace ngraph
class NGRAPH_API ScalarConstantLike : public ScalarConstantLikeBase class NGRAPH_API ScalarConstantLike : public ScalarConstantLikeBase
{ {
public: public:
static constexpr NodeTypeInfo type_info{"ScalarConstantLike", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief A scalar constant whose element type is the same as like. /// \brief A scalar constant whose element type is the same as like.
/// ///
/// Once the element type is known, the dependency on like will be removed and /// Once the element type is known, the dependency on like will be removed and
...@@ -390,6 +390,8 @@ namespace ngraph ...@@ -390,6 +390,8 @@ namespace ngraph
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
ScalarConstantLike() = default;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected: protected:
......
...@@ -52,7 +52,7 @@ namespace ngraph ...@@ -52,7 +52,7 @@ namespace ngraph
class NGRAPH_API GreaterEqual : public util::BinaryElementwiseComparison class NGRAPH_API GreaterEqual : public util::BinaryElementwiseComparison
{ {
public: public:
static constexpr NodeTypeInfo type_info{"GreaterEq", 1}; static constexpr NodeTypeInfo type_info{"GreaterEqual", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a greater-than-or-equal operation. /// \brief Constructs a greater-than-or-equal operation.
GreaterEqual() = default; GreaterEqual() = default;
......
...@@ -206,7 +206,7 @@ NGRAPH_OP(Result, ngraph::op, 0) ...@@ -206,7 +206,7 @@ NGRAPH_OP(Result, ngraph::op, 0)
NGRAPH_OP(Reverse, ngraph::op::v0, 0) NGRAPH_OP(Reverse, ngraph::op::v0, 0)
NGRAPH_OP(Reverse, ngraph::op::v1, 1) NGRAPH_OP(Reverse, ngraph::op::v1, 1)
NGRAPH_OP(ReverseSequence, ngraph::op::v0, 0) NGRAPH_OP(ReverseSequence, ngraph::op::v0, 0)
NGRAPH_OP(ScalarConstantLikeBase, ngraph::op, 0) NGRAPH_OP(ScalarConstantLike, ngraph::op, 0)
NGRAPH_OP(ScaleShift, ngraph::op::v0, 0) NGRAPH_OP(ScaleShift, ngraph::op::v0, 0)
NGRAPH_OP(ScatterAdd, ngraph::op::v0, 0) NGRAPH_OP(ScatterAdd, ngraph::op::v0, 0)
NGRAPH_OP(ScatterNDAdd, ngraph::op::v0, 0) NGRAPH_OP(ScatterNDAdd, ngraph::op::v0, 0)
......
...@@ -17,22 +17,50 @@ ...@@ -17,22 +17,50 @@
#include "ngraph/opsets/opset.hpp" #include "ngraph/opsets/opset.hpp"
#include "ngraph/ops.hpp" #include "ngraph/ops.hpp"
std::mutex& ngraph::OpSet::get_mutex()
{
static std::mutex opset_mutex;
return opset_mutex;
}
ngraph::Node* ngraph::OpSet::create(const std::string& name) const
{
auto type_info_it = m_name_type_info_map.find(name);
return type_info_it == m_name_type_info_map.end()
? nullptr
: FactoryRegistry<Node>::get().create(type_info_it->second);
}
const ngraph::OpSet& ngraph::get_opset0() const ngraph::OpSet& ngraph::get_opset0()
{ {
static OpSet opset({ static std::mutex init_mutex;
#define NGRAPH_OP(NAME, NAMESPACE) NAMESPACE::NAME::type_info, static OpSet opset;
if (opset.size() == 0)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (opset.size() == 0)
{
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset0_tbl.hpp" #include "ngraph/opsets/opset0_tbl.hpp"
#undef NGRAPH_OP #undef NGRAPH_OP
}); }
}
return opset; return opset;
} }
const ngraph::OpSet& ngraph::get_opset1() const ngraph::OpSet& ngraph::get_opset1()
{ {
static OpSet opset({ static std::mutex init_mutex;
#define NGRAPH_OP(NAME, NAMESPACE) NAMESPACE::NAME::type_info, static OpSet opset;
if (opset.size() == 0)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (opset.size() == 0)
{
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset1_tbl.hpp" #include "ngraph/opsets/opset1_tbl.hpp"
#undef NGRAPH_OP #undef NGRAPH_OP
}); }
}
return opset; return opset;
} }
\ No newline at end of file
...@@ -16,33 +16,86 @@ ...@@ -16,33 +16,86 @@
#pragma once #pragma once
#include <map>
#include <mutex>
#include <set> #include <set>
#include "ngraph/factory.hpp"
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
namespace ngraph namespace ngraph
{ {
class OpSet /// \brief Run-time opset information
class NGRAPH_API OpSet
{ {
static std::mutex& get_mutex();
public: public:
OpSet(const std::set<NodeTypeInfo>& op_types) std::set<NodeTypeInfo>::size_type size() const
: m_op_types(op_types) {
std::lock_guard<std::mutex> guard(get_mutex());
return m_op_types.size();
}
/// \brief Insert an op into the opset with a particular name and factory
void insert(const std::string& name,
const NodeTypeInfo& type_info,
FactoryRegistry<Node>::Factory factory)
{
std::lock_guard<std::mutex> guard(get_mutex());
m_op_types.insert(type_info);
m_name_type_info_map[name] = type_info;
ngraph::FactoryRegistry<Node>::get().register_factory(type_info, factory);
}
/// \brief Insert OP_TYPE into the opset with a special name and the default factory
template <typename OP_TYPE>
void insert(const std::string& name)
{
insert(name, OP_TYPE::type_info, FactoryRegistry<Node>::get_default_factory<OP_TYPE>());
}
/// \brief Insert OP_TYPE into the opset with the default name and factory
template <typename OP_TYPE>
void insert()
{
insert<OP_TYPE>(OP_TYPE::type_info.name);
}
/// \brief Create the op named name using it's factory
ngraph::Node* create(const std::string& name) const;
/// \brief Return true if OP_TYPE is in the opset
bool contains_type(const NodeTypeInfo& type_info) const
{ {
std::lock_guard<std::mutex> guard(get_mutex());
return m_op_types.find(type_info) != m_op_types.end();
} }
template <typename T> /// \brief Return true if OP_TYPE is in the opset
template <typename OP_TYPE>
bool contains_type() const bool contains_type() const
{ {
return m_op_types.find(T::type_info) != m_op_types.end(); return contains_type(OP_TYPE::type_info);
}
/// \brief Return true if name is in the opset
bool contains_type(const std::string& name) const
{
std::lock_guard<std::mutex> guard(get_mutex());
return m_name_type_info_map.find(name) != m_name_type_info_map.end();
} }
/// \brief Return true if node's type is in the opset
bool contains_op_type(Node* node) const bool contains_op_type(Node* node) const
{ {
std::lock_guard<std::mutex> guard(get_mutex());
return m_op_types.find(node->get_type_info()) != m_op_types.end(); return m_op_types.find(node->get_type_info()) != m_op_types.end();
} }
protected: protected:
std::set<NodeTypeInfo> m_op_types; std::set<NodeTypeInfo> m_op_types;
std::map<std::string, NodeTypeInfo> m_name_type_info_map;
}; };
const OpSet& get_opset0(); const OpSet& get_opset0();
......
...@@ -33,13 +33,11 @@ ...@@ -33,13 +33,11 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
#define TI(x) type_index(typeid(x))
static bool replace_broadcast_like(const std::shared_ptr<ngraph::Node>& node) static bool replace_broadcast_like(const std::shared_ptr<ngraph::Node>& node)
{ {
// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" // Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like"
// argument // argument
auto broadcast_like = static_pointer_cast<op::BroadcastLike>(node); auto broadcast_like = as_type_ptr<op::BroadcastLike>(node);
replace_node(node, replace_node(node,
make_shared<op::Broadcast>(broadcast_like->get_argument(0), make_shared<op::Broadcast>(broadcast_like->get_argument(0),
broadcast_like->get_broadcast_shape(), broadcast_like->get_broadcast_shape(),
...@@ -47,18 +45,28 @@ static bool replace_broadcast_like(const std::shared_ptr<ngraph::Node>& node) ...@@ -47,18 +45,28 @@ static bool replace_broadcast_like(const std::shared_ptr<ngraph::Node>& node)
return true; return true;
} }
static const unordered_map<type_index, function<bool(const shared_ptr<Node>&)>> dispatcher{ static bool replace_scalar_constant_like(const std::shared_ptr<Node>& node)
{TI(op::BroadcastLike), &replace_broadcast_like}}; {
auto scalar_constant_like = as_type_ptr<op::ScalarConstantLike>(node);
replace_node(node, scalar_constant_like->as_constant());
return true;
}
bool pass::LikeReplacement::run_on_function(shared_ptr<Function> function) static const map<NodeTypeInfo, function<bool(const shared_ptr<Node>&)>> dispatcher{
{op::BroadcastLike::type_info, replace_broadcast_like},
{op::ScalarConstantLike::type_info, replace_scalar_constant_like}};
bool pass::LikeReplacement::run_on_function(shared_ptr<Function> function_ptr)
{ {
bool clobbered = false; static const map<NodeTypeInfo, function<bool(const shared_ptr<Node>&)>> dispatcher{
{op::BroadcastLike::type_info, replace_broadcast_like},
{op::ScalarConstantLike::type_info, replace_scalar_constant_like}};
for (const auto& n : function->get_ops()) bool clobbered = false;
for (const auto& n : function_ptr->get_ops())
{ {
// Work around a warning [-Wpotentially-evaluated-expression] // Work around a warning [-Wpotentially-evaluated-expression]
const Node& node = *n; auto handler = dispatcher.find(n->get_type_info());
auto handler = dispatcher.find(TI(node));
if (handler != dispatcher.end()) if (handler != dispatcher.end())
{ {
clobbered = handler->second(n) || clobbered; clobbered = handler->second(n) || clobbered;
......
...@@ -549,7 +549,6 @@ private: ...@@ -549,7 +549,6 @@ private:
reference::constant<T>(c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), element_count); reference::constant<T>(c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::ScalarConstantLike: break;
case OP_TYPEID::Convert: case OP_TYPEID::Convert:
{ {
// const op::Convert* c = static_cast<const op::Convert*>(&node); // const op::Convert* c = static_cast<const op::Convert*>(&node);
...@@ -1879,6 +1878,7 @@ private: ...@@ -1879,6 +1878,7 @@ private:
case OP_TYPEID::PartialSlice: case OP_TYPEID::PartialSlice:
case OP_TYPEID::PartialSliceBackprop: case OP_TYPEID::PartialSliceBackprop:
case OP_TYPEID::RNNCell: case OP_TYPEID::RNNCell:
case OP_TYPEID::ScalarConstantLike:
case OP_TYPEID::ScaleShift: case OP_TYPEID::ScaleShift:
case OP_TYPEID::Selu: case OP_TYPEID::Selu:
case OP_TYPEID::ShuffleChannels: case OP_TYPEID::ShuffleChannels:
......
...@@ -2587,7 +2587,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2587,7 +2587,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::ReorgYolo: { break; case OP_TYPEID::ReorgYolo: { break;
} }
case OP_TYPEID::ScalarConstantLikeBase: case OP_TYPEID::ScalarConstantLike:
{ {
double value = node_js.at("value").get<double>(); double value = node_js.at("value").get<double>();
node = make_shared<op::ScalarConstantLike>(args[0], value); node = make_shared<op::ScalarConstantLike>(args[0], value);
...@@ -4329,7 +4329,7 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -4329,7 +4329,7 @@ json JSONSerializer::serialize_node(const Node& n)
node["activations_beta"] = tmp->get_activations_beta(); node["activations_beta"] = tmp->get_activations_beta();
break; break;
} }
case OP_TYPEID::ScalarConstantLikeBase: case OP_TYPEID::ScalarConstantLike:
{ {
auto tmp = static_cast<const op::ScalarConstantLikeBase*>(&n); auto tmp = static_cast<const op::ScalarConstantLikeBase*>(&n);
auto constant = tmp->as_constant(); auto constant = tmp->as_constant();
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset.hpp"
#include "ngraph/opsets/opset1.hpp" #include "ngraph/opsets/opset1.hpp"
#include <memory> #include <memory>
...@@ -27,7 +28,13 @@ using namespace ngraph; ...@@ -27,7 +28,13 @@ using namespace ngraph;
#define CHECK_OPSET(op1, op2) \ #define CHECK_OPSET(op1, op2) \
EXPECT_TRUE(is_type<op1>(make_shared<op2>())); \ EXPECT_TRUE(is_type<op1>(make_shared<op2>())); \
EXPECT_TRUE((std::is_same<op1, op2>::value)); EXPECT_TRUE((std::is_same<op1, op2>::value)); \
EXPECT_TRUE((get_opset1().contains_type<op2>())); \
{ \
shared_ptr<Node> op(get_opset1().create(op2::type_info.name)); \
ASSERT_TRUE(op); \
EXPECT_TRUE(is_type<op2>(op)); \
}
TEST(opset, check_opset1) TEST(opset, check_opset1)
{ {
...@@ -143,3 +150,42 @@ TEST(opset, check_opset1) ...@@ -143,3 +150,42 @@ TEST(opset, check_opset1)
CHECK_OPSET(op::v1::VariadicSplit, opset1::VariadicSplit) CHECK_OPSET(op::v1::VariadicSplit, opset1::VariadicSplit)
CHECK_OPSET(op::v0::Xor, opset1::Xor) CHECK_OPSET(op::v0::Xor, opset1::Xor)
} }
class NewOp : public op::Op
{
public:
NewOp() = default;
static constexpr NodeTypeInfo type_info{"NewOp", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
void validate_and_infer_types() override{};
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& /* new_args */) const override
{
return make_shared<NewOp>();
};
};
constexpr NodeTypeInfo NewOp::type_info;
TEST(opset, new_op)
{
// Copy opset1; don't bash the real thing in a test
OpSet opset1_copy(get_opset1());
opset1_copy.insert<NewOp>();
{
shared_ptr<Node> op(opset1_copy.create(NewOp::type_info.name));
ASSERT_TRUE(op);
EXPECT_TRUE(is_type<NewOp>(op));
}
shared_ptr<Node> fred;
fred = shared_ptr<Node>(opset1_copy.create("Fred"));
EXPECT_FALSE(fred);
opset1_copy.insert<NewOp>("Fred");
// Make sure we copied
fred = shared_ptr<Node>(get_opset1().create("Fred"));
ASSERT_FALSE(fred);
// Fred should be in the copy
fred = shared_ptr<Node>(opset1_copy.create("Fred"));
EXPECT_TRUE(fred);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment