Unverified Commit 8adf78fe authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Runtime-extensible opsets (#3991)

* Fix some opset bugs
Typo in opset0
Include ops.hpp rather than ngraph.hpp

* Opset op insertion

* Make opsets extendable, able to create instances

* Update like replacement

* Review comments
parent 83dcc092
......@@ -19,42 +19,82 @@
#include <functional>
#include <map>
#include <mutex>
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph
{
NGRAPH_API std::mutex& get_registry_mutex();
template <typename T>
/// \brief Registry of factories that can construct objects derived from BASE_TYPE
template <typename BASE_TYPE>
class FactoryRegistry
{
public:
using Factory = std::function<T*()>;
using FactoryMap = std::map<decltype(T::type_info), Factory>;
using Factory = std::function<BASE_TYPE*()>;
using FactoryMap = std::map<decltype(BASE_TYPE::type_info), Factory>;
template <typename U>
void register_factory()
// \brief Get the default factory for DERIVED_TYPE. Specialize as needed.
template <typename DERIVED_TYPE>
static Factory get_default_factory()
{
return []() { return new DERIVED_TYPE(); };
}
/// \brief Register a custom factory for type_info
void register_factory(const decltype(BASE_TYPE::type_info) & type_info, Factory factory)
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
m_factory_map[U::type_info] = []() { return new U(); };
m_factory_map[type_info] = factory;
}
bool has_factory(const decltype(T::type_info) & info)
/// \brief Register a custom factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
void register_factory(Factory factory)
{
register_factory(DERIVED_TYPE::type_info, factory);
}
/// \brief Register the defualt constructor factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
void register_factory()
{
register_factory<DERIVED_TYPE>(get_default_factory<DERIVED_TYPE>());
}
/// \brief Check to see if a factory is registered
bool has_factory(const decltype(BASE_TYPE::type_info) & info)
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
return m_factory_map.find(info) != m_factory_map.end();
}
T* create(const decltype(T::type_info) & info)
/// \brief Check to see if DERIVED_TYPE has a registered factory
template <typename DERIVED_TYPE>
bool has_factory()
{
return has_factory(DERIVED_TYPE::type_info);
}
/// \brief Create an instance for type_info
BASE_TYPE* create(const decltype(BASE_TYPE::type_info) & type_info)
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
auto it = m_factory_map.find(info);
auto it = m_factory_map.find(type_info);
return it == m_factory_map.end() ? nullptr : it->second();
}
static FactoryRegistry<T>& get();
/// \brief Create an instance using factory for DERIVED_TYPE
template <typename DERIVED_TYPE>
BASE_TYPE* create()
{
return create(DERIVED_TYPE::type_info);
}
/// \brief Get the factory for BASE_TYPE
static FactoryRegistry<BASE_TYPE>& get();
protected:
// Need a Compare on type_info
FactoryMap m_factory_map;
};
}
......@@ -337,7 +337,7 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
return rc;
}
constexpr NodeTypeInfo op::ScalarConstantLikeBase::type_info;
constexpr NodeTypeInfo op::ScalarConstantLike::type_info;
shared_ptr<op::Constant> op::ScalarConstantLikeBase::as_constant() const
{
......
......@@ -359,8 +359,6 @@ namespace ngraph
class NGRAPH_API ScalarConstantLikeBase : public Constant
{
public:
static constexpr NodeTypeInfo type_info{"ScalarConstantLikeBase", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
std::shared_ptr<op::Constant> as_constant() const;
ScalarConstantLikeBase() = default;
......@@ -375,6 +373,8 @@ namespace ngraph
class NGRAPH_API ScalarConstantLike : public ScalarConstantLikeBase
{
public:
static constexpr NodeTypeInfo type_info{"ScalarConstantLike", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief A scalar constant whose element type is the same as like.
///
/// Once the element type is known, the dependency on like will be removed and
......@@ -390,6 +390,8 @@ namespace ngraph
constructor_validate_and_infer_types();
}
ScalarConstantLike() = default;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected:
......
......@@ -52,7 +52,7 @@ namespace ngraph
class NGRAPH_API GreaterEqual : public util::BinaryElementwiseComparison
{
public:
static constexpr NodeTypeInfo type_info{"GreaterEq", 1};
static constexpr NodeTypeInfo type_info{"GreaterEqual", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a greater-than-or-equal operation.
GreaterEqual() = default;
......
......@@ -206,7 +206,7 @@ NGRAPH_OP(Result, ngraph::op, 0)
NGRAPH_OP(Reverse, ngraph::op::v0, 0)
NGRAPH_OP(Reverse, ngraph::op::v1, 1)
NGRAPH_OP(ReverseSequence, ngraph::op::v0, 0)
NGRAPH_OP(ScalarConstantLikeBase, ngraph::op, 0)
NGRAPH_OP(ScalarConstantLike, ngraph::op, 0)
NGRAPH_OP(ScaleShift, ngraph::op::v0, 0)
NGRAPH_OP(ScatterAdd, ngraph::op::v0, 0)
NGRAPH_OP(ScatterNDAdd, ngraph::op::v0, 0)
......
......@@ -17,22 +17,50 @@
#include "ngraph/opsets/opset.hpp"
#include "ngraph/ops.hpp"
std::mutex& ngraph::OpSet::get_mutex()
{
static std::mutex opset_mutex;
return opset_mutex;
}
ngraph::Node* ngraph::OpSet::create(const std::string& name) const
{
auto type_info_it = m_name_type_info_map.find(name);
return type_info_it == m_name_type_info_map.end()
? nullptr
: FactoryRegistry<Node>::get().create(type_info_it->second);
}
const ngraph::OpSet& ngraph::get_opset0()
{
static OpSet opset({
#define NGRAPH_OP(NAME, NAMESPACE) NAMESPACE::NAME::type_info,
static std::mutex init_mutex;
static OpSet opset;
if (opset.size() == 0)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (opset.size() == 0)
{
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset0_tbl.hpp"
#undef NGRAPH_OP
});
}
}
return opset;
}
const ngraph::OpSet& ngraph::get_opset1()
{
static OpSet opset({
#define NGRAPH_OP(NAME, NAMESPACE) NAMESPACE::NAME::type_info,
static std::mutex init_mutex;
static OpSet opset;
if (opset.size() == 0)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (opset.size() == 0)
{
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset1_tbl.hpp"
#undef NGRAPH_OP
});
}
}
return opset;
}
\ No newline at end of file
}
......@@ -16,33 +16,86 @@
#pragma once
#include <map>
#include <mutex>
#include <set>
#include "ngraph/factory.hpp"
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
class OpSet
/// \brief Run-time opset information
class NGRAPH_API OpSet
{
static std::mutex& get_mutex();
public:
OpSet(const std::set<NodeTypeInfo>& op_types)
: m_op_types(op_types)
std::set<NodeTypeInfo>::size_type size() const
{
std::lock_guard<std::mutex> guard(get_mutex());
return m_op_types.size();
}
/// \brief Insert an op into the opset with a particular name and factory
void insert(const std::string& name,
const NodeTypeInfo& type_info,
FactoryRegistry<Node>::Factory factory)
{
std::lock_guard<std::mutex> guard(get_mutex());
m_op_types.insert(type_info);
m_name_type_info_map[name] = type_info;
ngraph::FactoryRegistry<Node>::get().register_factory(type_info, factory);
}
/// \brief Insert OP_TYPE into the opset with a special name and the default factory
template <typename OP_TYPE>
void insert(const std::string& name)
{
insert(name, OP_TYPE::type_info, FactoryRegistry<Node>::get_default_factory<OP_TYPE>());
}
/// \brief Insert OP_TYPE into the opset with the default name and factory
template <typename OP_TYPE>
void insert()
{
insert<OP_TYPE>(OP_TYPE::type_info.name);
}
/// \brief Create the op named name using it's factory
ngraph::Node* create(const std::string& name) const;
/// \brief Return true if OP_TYPE is in the opset
bool contains_type(const NodeTypeInfo& type_info) const
{
std::lock_guard<std::mutex> guard(get_mutex());
return m_op_types.find(type_info) != m_op_types.end();
}
template <typename T>
/// \brief Return true if OP_TYPE is in the opset
template <typename OP_TYPE>
bool contains_type() const
{
return m_op_types.find(T::type_info) != m_op_types.end();
return contains_type(OP_TYPE::type_info);
}
/// \brief Return true if name is in the opset
bool contains_type(const std::string& name) const
{
std::lock_guard<std::mutex> guard(get_mutex());
return m_name_type_info_map.find(name) != m_name_type_info_map.end();
}
/// \brief Return true if node's type is in the opset
bool contains_op_type(Node* node) const
{
std::lock_guard<std::mutex> guard(get_mutex());
return m_op_types.find(node->get_type_info()) != m_op_types.end();
}
protected:
std::set<NodeTypeInfo> m_op_types;
std::map<std::string, NodeTypeInfo> m_name_type_info_map;
};
const OpSet& get_opset0();
......
......@@ -33,13 +33,11 @@
using namespace std;
using namespace ngraph;
#define TI(x) type_index(typeid(x))
static bool replace_broadcast_like(const std::shared_ptr<ngraph::Node>& node)
{
// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like"
// argument
auto broadcast_like = static_pointer_cast<op::BroadcastLike>(node);
auto broadcast_like = as_type_ptr<op::BroadcastLike>(node);
replace_node(node,
make_shared<op::Broadcast>(broadcast_like->get_argument(0),
broadcast_like->get_broadcast_shape(),
......@@ -47,18 +45,28 @@ static bool replace_broadcast_like(const std::shared_ptr<ngraph::Node>& node)
return true;
}
static const unordered_map<type_index, function<bool(const shared_ptr<Node>&)>> dispatcher{
{TI(op::BroadcastLike), &replace_broadcast_like}};
static bool replace_scalar_constant_like(const std::shared_ptr<Node>& node)
{
auto scalar_constant_like = as_type_ptr<op::ScalarConstantLike>(node);
replace_node(node, scalar_constant_like->as_constant());
return true;
}
bool pass::LikeReplacement::run_on_function(shared_ptr<Function> function)
static const map<NodeTypeInfo, function<bool(const shared_ptr<Node>&)>> dispatcher{
{op::BroadcastLike::type_info, replace_broadcast_like},
{op::ScalarConstantLike::type_info, replace_scalar_constant_like}};
bool pass::LikeReplacement::run_on_function(shared_ptr<Function> function_ptr)
{
bool clobbered = false;
static const map<NodeTypeInfo, function<bool(const shared_ptr<Node>&)>> dispatcher{
{op::BroadcastLike::type_info, replace_broadcast_like},
{op::ScalarConstantLike::type_info, replace_scalar_constant_like}};
for (const auto& n : function->get_ops())
bool clobbered = false;
for (const auto& n : function_ptr->get_ops())
{
// Work around a warning [-Wpotentially-evaluated-expression]
const Node& node = *n;
auto handler = dispatcher.find(TI(node));
auto handler = dispatcher.find(n->get_type_info());
if (handler != dispatcher.end())
{
clobbered = handler->second(n) || clobbered;
......
......@@ -549,7 +549,6 @@ private:
reference::constant<T>(c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::ScalarConstantLike: break;
case OP_TYPEID::Convert:
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
......@@ -1879,6 +1878,7 @@ private:
case OP_TYPEID::PartialSlice:
case OP_TYPEID::PartialSliceBackprop:
case OP_TYPEID::RNNCell:
case OP_TYPEID::ScalarConstantLike:
case OP_TYPEID::ScaleShift:
case OP_TYPEID::Selu:
case OP_TYPEID::ShuffleChannels:
......
......@@ -2587,7 +2587,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::ReorgYolo: { break;
}
case OP_TYPEID::ScalarConstantLikeBase:
case OP_TYPEID::ScalarConstantLike:
{
double value = node_js.at("value").get<double>();
node = make_shared<op::ScalarConstantLike>(args[0], value);
......@@ -4329,7 +4329,7 @@ json JSONSerializer::serialize_node(const Node& n)
node["activations_beta"] = tmp->get_activations_beta();
break;
}
case OP_TYPEID::ScalarConstantLikeBase:
case OP_TYPEID::ScalarConstantLike:
{
auto tmp = static_cast<const op::ScalarConstantLikeBase*>(&n);
auto constant = tmp->as_constant();
......
......@@ -17,6 +17,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset.hpp"
#include "ngraph/opsets/opset1.hpp"
#include <memory>
......@@ -27,7 +28,13 @@ using namespace ngraph;
#define CHECK_OPSET(op1, op2) \
EXPECT_TRUE(is_type<op1>(make_shared<op2>())); \
EXPECT_TRUE((std::is_same<op1, op2>::value));
EXPECT_TRUE((std::is_same<op1, op2>::value)); \
EXPECT_TRUE((get_opset1().contains_type<op2>())); \
{ \
shared_ptr<Node> op(get_opset1().create(op2::type_info.name)); \
ASSERT_TRUE(op); \
EXPECT_TRUE(is_type<op2>(op)); \
}
TEST(opset, check_opset1)
{
......@@ -143,3 +150,42 @@ TEST(opset, check_opset1)
CHECK_OPSET(op::v1::VariadicSplit, opset1::VariadicSplit)
CHECK_OPSET(op::v0::Xor, opset1::Xor)
}
class NewOp : public op::Op
{
public:
NewOp() = default;
static constexpr NodeTypeInfo type_info{"NewOp", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
void validate_and_infer_types() override{};
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& /* new_args */) const override
{
return make_shared<NewOp>();
};
};
constexpr NodeTypeInfo NewOp::type_info;
TEST(opset, new_op)
{
// Copy opset1; don't bash the real thing in a test
OpSet opset1_copy(get_opset1());
opset1_copy.insert<NewOp>();
{
shared_ptr<Node> op(opset1_copy.create(NewOp::type_info.name));
ASSERT_TRUE(op);
EXPECT_TRUE(is_type<NewOp>(op));
}
shared_ptr<Node> fred;
fred = shared_ptr<Node>(opset1_copy.create("Fred"));
EXPECT_FALSE(fred);
opset1_copy.insert<NewOp>("Fred");
// Make sure we copied
fred = shared_ptr<Node>(get_opset1().create("Fred"));
ASSERT_FALSE(fred);
// Fred should be in the copy
fred = shared_ptr<Node>(opset1_copy.create("Fred"));
EXPECT_TRUE(fred);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment