Unverified Commit d52473c8 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Change GPU backend to use op_tbl (#1618)

* sort op list

* use op_tbl

* throw unsupported_op exception when appropriate

* remove dead code

* Add more use of NGRAPH_OP macro to remove boilerplate definitions/implementations

* revert moving class out of namespace

* change from switch to dispatcher map
parent 3d69bf7a
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -24,12 +24,6 @@ ...@@ -24,12 +24,6 @@
#include "ngraph/runtime/gpu/gpu_external_function.hpp" #include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_view_wrapper.hpp" #include "ngraph/runtime/gpu/gpu_tensor_view_wrapper.hpp"
#define EMITTER_DECL(op_name) \
emit<op_name>(GPU_ExternalFunction * external_function, \
codegen::CodeWriter & writer, \
const ngraph::Node* node, \
const std::vector<GPU_TensorViewWrapper>& args, \
const std::vector<GPU_TensorViewWrapper>& out)
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
...@@ -39,31 +33,17 @@ namespace ngraph ...@@ -39,31 +33,17 @@ namespace ngraph
class GPU_Emitter class GPU_Emitter
{ {
public: public:
template <typename OP> static std::function<void(EMIT_ARGS)> get_emit_function(const Node& node);
static void emit(GPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<GPU_TensorViewWrapper>& args,
const std::vector<GPU_TensorViewWrapper>& out)
{
throw std::runtime_error("Unimplemented op '" + node->description() +
"' in GPU emitter");
}
static void nop(GPU_ExternalFunction* external_function, // This defines a collection of function declarations like this
codegen::CodeWriter& writer, // static void emit_Abs(EMIT_ARGS);
const ngraph::Node* node, // static void emit_Acos(EMIT_ARGS);
const std::vector<GPU_TensorViewWrapper>& args, #define NGRAPH_OP(a) static void emit_##a(EMIT_ARGS);
const std::vector<GPU_TensorViewWrapper>& out) #include "ngraph/op/op_tbl.hpp"
{ #undef NGRAPH_OP
}
template <typename T> template <typename T>
static void emit_elementwise(GPU_ExternalFunction* external_function, static void emit_elementwise(EMIT_ARGS)
codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<GPU_TensorViewWrapper>& args,
const std::vector<GPU_TensorViewWrapper>& out)
{ {
if (out[0].get_size() == 0) if (out[0].get_size() == 0)
{ {
...@@ -104,6 +84,7 @@ namespace ngraph ...@@ -104,6 +84,7 @@ namespace ngraph
static std::string node_names(const std::vector<GPU_TensorViewWrapper>& args, static std::string node_names(const std::vector<GPU_TensorViewWrapper>& args,
std::initializer_list<int> arg_indexes = {}); std::initializer_list<int> arg_indexes = {});
}; };
Shape get_padded_shape(const Shape& input_shape, Shape get_padded_shape(const Shape& input_shape,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
......
...@@ -36,6 +36,11 @@ ...@@ -36,6 +36,11 @@
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp" #include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_view_wrapper.hpp" #include "ngraph/runtime/gpu/gpu_tensor_view_wrapper.hpp"
#define EMIT_ARGS \
runtime::gpu::GPU_ExternalFunction *external_function, codegen::CodeWriter &writer, \
const Node *node, const std::vector<runtime::gpu::GPU_TensorViewWrapper> &args, \
const std::vector<runtime::gpu::GPU_TensorViewWrapper> &out
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
...@@ -46,15 +51,6 @@ namespace ngraph ...@@ -46,15 +51,6 @@ namespace ngraph
class GPU_CallFrame; class GPU_CallFrame;
struct GPURuntimeContext; struct GPURuntimeContext;
using OpFunction =
std::function<void(GPU_ExternalFunction* external_function,
codegen::CodeWriter&,
const ngraph::Node*,
const std::vector<GPU_TensorViewWrapper>& inputs,
const std::vector<GPU_TensorViewWrapper>& outputs)>;
using OpMap = std::unordered_map<std::type_index, OpFunction>;
class GPU_ExternalFunction : public std::enable_shared_from_this<GPU_ExternalFunction> class GPU_ExternalFunction : public std::enable_shared_from_this<GPU_ExternalFunction>
{ {
friend class GPU_CallFrame; friend class GPU_CallFrame;
...@@ -97,6 +93,7 @@ namespace ngraph ...@@ -97,6 +93,7 @@ namespace ngraph
void emit_debug_function_entry(Node* node); void emit_debug_function_entry(Node* node);
void emit_debug_function_exit(Node* node); void emit_debug_function_exit(Node* node);
void emit_temp_mem_pool_allocation(std::shared_ptr<Function> current_function); void emit_temp_mem_pool_allocation(std::shared_ptr<Function> current_function);
void emit_op(EMIT_ARGS);
void release_function() { m_function = nullptr; } void release_function() { m_function = nullptr; }
void store_emitted_functions(const std::string& code); void store_emitted_functions(const std::string& code);
std::string emit_op_as_function(const Node& node, const std::string& function_name); std::string emit_op_as_function(const Node& node, const std::string& function_name);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment