Unverified Commit d81d0c93 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Interpreter use switch() for main loop (#1538)

* wip

* interperter use switch instead of if/else

* more cleanup

* make nop elimination run on all backends

* revert

* use single include file to define all ops so there is only one instance

* move op.tbl to ngraph/op dir as it is useful. Added useage example.

* add some comments where needed

* revert some changes to reduce delta

* add const

* add more const

* simplify using NodeWrapper

* update per review comments

* update per review comments

* update per review comments

* remove switch warning as it is not supported in older gcc
parent 5032f343
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// This collection contains one entry for each op. If an op is added it must be
// added to this list.
//
// In order to use this list you want to define a macro named exactly NGRAPH_OP
// When you are done you should undef the macro
// As an example if you wanted to make a list of all op names as strings you could do this:
//
// #define NGRAPH_OP(a) #a,
// std::vector<std::string> op_names{
// #include "this include file name"
// };
// #undef NGRAPH_OP
//
// This sample expands to a list like this:
// "Abs",
// "Acos",
// ...
//
// It's that easy. You can use this for fun and profit.
NGRAPH_OP(Abs)
NGRAPH_OP(Acos)
NGRAPH_OP(Add)
NGRAPH_OP(AllReduce)
NGRAPH_OP(And)
NGRAPH_OP(ArgMax)
NGRAPH_OP(ArgMin)
NGRAPH_OP(Asin)
NGRAPH_OP(Atan)
NGRAPH_OP(AvgPool)
NGRAPH_OP(AvgPoolBackprop)
NGRAPH_OP(BatchNorm)
NGRAPH_OP(BatchNormBackprop)
NGRAPH_OP(Broadcast)
NGRAPH_OP(Ceiling)
NGRAPH_OP(Concat)
NGRAPH_OP(Constant)
NGRAPH_OP(Convert)
NGRAPH_OP(Convolution)
NGRAPH_OP(ConvolutionBackpropData)
NGRAPH_OP(ConvolutionBackpropFilters)
NGRAPH_OP(Cos)
NGRAPH_OP(Cosh)
NGRAPH_OP(Divide)
NGRAPH_OP(Dot)
NGRAPH_OP(Equal)
NGRAPH_OP(Exp)
NGRAPH_OP(Floor)
NGRAPH_OP(FunctionCall)
NGRAPH_OP(GetOutputElement)
NGRAPH_OP(Greater)
NGRAPH_OP(GreaterEq)
NGRAPH_OP(Less)
NGRAPH_OP(LessEq)
NGRAPH_OP(Log)
NGRAPH_OP(LRN)
NGRAPH_OP(Max)
NGRAPH_OP(Maximum)
NGRAPH_OP(MaxPool)
NGRAPH_OP(MaxPoolBackprop)
NGRAPH_OP(Min)
NGRAPH_OP(Minimum)
NGRAPH_OP(Multiply)
NGRAPH_OP(Negative)
NGRAPH_OP(Not)
NGRAPH_OP(NotEqual)
NGRAPH_OP(OneHot)
NGRAPH_OP(Or)
NGRAPH_OP(Pad)
NGRAPH_OP(Parameter)
NGRAPH_OP(Power)
NGRAPH_OP(Product)
NGRAPH_OP(Reduce)
NGRAPH_OP(ReduceWindow)
NGRAPH_OP(Relu)
NGRAPH_OP(ReluBackprop)
NGRAPH_OP(ReplaceSlice)
NGRAPH_OP(Reshape)
NGRAPH_OP(Result)
NGRAPH_OP(Reverse)
NGRAPH_OP(ReverseSequence)
NGRAPH_OP(Select)
NGRAPH_OP(SelectAndScatter)
NGRAPH_OP(Sigmoid)
NGRAPH_OP(SigmoidBackprop)
NGRAPH_OP(Sign)
NGRAPH_OP(Sin)
NGRAPH_OP(Sinh)
NGRAPH_OP(Slice)
NGRAPH_OP(Softmax)
NGRAPH_OP(Sqrt)
NGRAPH_OP(StopGradient)
NGRAPH_OP(Subtract)
NGRAPH_OP(Sum)
NGRAPH_OP(Tan)
NGRAPH_OP(Tanh)
NGRAPH_OP(TopK)
...@@ -175,28 +175,7 @@ map<string, string> runtime::BackendManager::get_registered_device_map() ...@@ -175,28 +175,7 @@ map<string, string> runtime::BackendManager::get_registered_device_map()
string backend_name; string backend_name;
if (is_backend_name(name, backend_name)) if (is_backend_name(name, backend_name))
{ {
DL_HANDLE handle; rc.insert({to_upper(backend_name), file});
#ifdef WIN32
handle = LoadLibrary(file.c_str());
#else
handle = dlopen(file.c_str(), RTLD_LAZY | RTLD_LOCAL);
#endif
if (handle)
{
if (DLSYM(handle, "new_backend") && DLSYM(handle, "delete_backend"))
{
function<const char*()> get_ngraph_version_string =
reinterpret_cast<const char* (*)()>(
DLSYM(handle, "get_ngraph_version_string"));
if (get_ngraph_version_string &&
get_ngraph_version_string() == string(NGRAPH_VERSION))
{
rc.insert({to_upper(backend_name), file});
}
}
CLOSE_LIBRARY(handle);
}
} }
}; };
file_util::iterate_files(my_directory, f, false, true); file_util::iterate_files(my_directory, f, false, true);
......
...@@ -23,8 +23,8 @@ if(NGRAPH_DISTRIBUTED_ENABLE) ...@@ -23,8 +23,8 @@ if(NGRAPH_DISTRIBUTED_ENABLE)
endif() endif()
if (NGRAPH_INTERPRETER_ENABLE) if (NGRAPH_INTERPRETER_ENABLE)
add_library(interpreter_backend SHARED int_backend.cpp) add_library(interpreter_backend SHARED int_backend.cpp node_wrapper.cpp)
set_target_properties(interpreter_backend PROPERTIES VERSION ${NGRAPH_VERSION} SOVERSION ${NGRAPH_API_VERSION}) set_target_properties(interpreter_backend PROPERTIES VERSION ${NGRAPH_VERSION})
target_link_libraries(interpreter_backend PUBLIC ngraph) target_link_libraries(interpreter_backend PUBLIC ngraph)
set_target_properties(interpreter_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR}) set_target_properties(interpreter_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR})
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/runtime/interpreter/int_backend.hpp" #include "ngraph/runtime/interpreter/int_backend.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp" #include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/convert.hpp" #include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp" #include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp" #include "ngraph/op/util/binary_elementwise_comparison.hpp"
...@@ -68,6 +69,11 @@ bool runtime::interpreter::INTBackend::compile(shared_ptr<Function> function) ...@@ -68,6 +69,11 @@ bool runtime::interpreter::INTBackend::compile(shared_ptr<Function> function)
pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>(); pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
pass_manager.run_passes(function); pass_manager.run_passes(function);
for (const shared_ptr<Node>& node : function->get_ordered_ops())
{
instance.m_wrapped_nodes.emplace_back(node);
}
} }
return true; return true;
...@@ -125,13 +131,14 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -125,13 +131,14 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
} }
// for each ordered op in the graph // for each ordered op in the graph
for (shared_ptr<Node> op : function->get_ordered_ops()) for (const NodeWrapper& wrapped : instance.m_wrapped_nodes)
{ {
const Node* op = &wrapped.get_node();
auto type_id = wrapped.get_typeid();
if (op->description() == "Parameter") if (op->description() == "Parameter")
{ {
continue; continue;
} }
// get op inputs from map // get op inputs from map
vector<shared_ptr<runtime::HostTensorView>> op_inputs; vector<shared_ptr<runtime::HostTensorView>> op_inputs;
for (const descriptor::Input& input : op->get_inputs()) for (const descriptor::Input& input : op->get_inputs())
...@@ -164,35 +171,37 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -164,35 +171,37 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
// get op type // get op type
element::Type type; element::Type type;
if (dynamic_pointer_cast<op::util::BinaryElementwiseComparison>(op) || switch (type_id)
dynamic_pointer_cast<op::Select>(op))
{ {
case OP_TYPEID::Convert:
type = op->get_inputs().at(0).get_tensor().get_element_type();
break;
case OP_TYPEID::Equal:
case OP_TYPEID::Greater:
case OP_TYPEID::GreaterEq:
case OP_TYPEID::Less:
case OP_TYPEID::LessEq:
case OP_TYPEID::NotEqual:
// Get the type of the second input, not the first // Get the type of the second input, not the first
// All BinaryElementwiseComparision ops have the same type for inputs // All BinaryElementwiseComparision ops have the same type for inputs
// Select has bool for first input and the type we are interested in for the second // Select has bool for first input and the type we are interested in for the second
type = op->get_inputs().at(1).get_tensor().get_element_type(); type = op->get_inputs().at(1).get_tensor().get_element_type();
} break;
else if (dynamic_pointer_cast<op::Convert>(op)) default: type = op->get_outputs().at(0).get_element_type(); break;
{
type = op->get_inputs().at(0).get_tensor().get_element_type();
}
else
{
type = op->get_outputs().at(0).get_element_type();
} }
if (instance.m_performance_counters_enabled) if (instance.m_performance_counters_enabled)
{ {
instance.m_timer_map[op.get()].start(); instance.m_timer_map[op].start();
} }
generate_calls(type, *op, op_outputs, op_inputs); generate_calls(type, wrapped, op_outputs, op_inputs);
if (instance.m_performance_counters_enabled) if (instance.m_performance_counters_enabled)
{ {
instance.m_timer_map[op.get()].stop(); instance.m_timer_map[op].stop();
} }
if (instance.m_nan_check_enabled) if (instance.m_nan_check_enabled)
{ {
perform_nan_check(op_outputs, op.get()); perform_nan_check(op_outputs, op);
} }
// delete any obsolete tensors // delete any obsolete tensors
...@@ -214,7 +223,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -214,7 +223,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
void runtime::interpreter::INTBackend::generate_calls( void runtime::interpreter::INTBackend::generate_calls(
const element::Type& type, const element::Type& type,
Node& op, const NodeWrapper& op,
const vector<shared_ptr<HostTensorView>>& outputs, const vector<shared_ptr<HostTensorView>>& outputs,
const vector<shared_ptr<HostTensorView>>& inputs) const vector<shared_ptr<HostTensorView>>& inputs)
{ {
...@@ -265,7 +274,7 @@ void runtime::interpreter::INTBackend::generate_calls( ...@@ -265,7 +274,7 @@ void runtime::interpreter::INTBackend::generate_calls(
else else
{ {
stringstream ss; stringstream ss;
ss << "unsupported element type " << type << " op " << op.get_name(); ss << "unsupported element type " << type << " op " << op.get_node().get_name();
throw ngraph_error(ss.str()); throw ngraph_error(ss.str());
} }
} }
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
using namespace ngraph;
using namespace std;
runtime::interpreter::NodeWrapper::NodeWrapper(const shared_ptr<const Node>& node)
: m_node{node}
{
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// {"Abs", runtime::interpreter::OP_TYPEID::Abs},
// {"Acos", runtime::interpreter::OP_TYPEID::Acos},
// ...
#define NGRAPH_OP(a) {#a, runtime::interpreter::OP_TYPEID::a},
static unordered_map<string, runtime::interpreter::OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
auto it = typeid_map.find(m_node->description());
if (it != typeid_map.end())
{
m_typeid = it->second;
}
else
{
// TODO: use unsupported_op when that is merged to master
throw runtime_error(m_node->description());
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node.hpp"
namespace ngraph
{
namespace runtime
{
namespace interpreter
{
enum class OP_TYPEID;
class NodeWrapper;
}
}
}
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs,
// Acos,
// ...
#define NGRAPH_OP(a) a,
enum class ngraph::runtime::interpreter::OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
/// \brief This class allows adding an enum typeid to each Node. This makes dealing with
/// collections of Nodes a little easier and faster as we can use switch() instead of
/// if/else statements
class ngraph::runtime::interpreter::NodeWrapper
{
public:
NodeWrapper(const std::shared_ptr<const ngraph::Node>& node);
const Node& get_node() const { return *m_node; }
ngraph::runtime::interpreter::OP_TYPEID get_typeid() const { return m_typeid; }
private:
std::shared_ptr<const ngraph::Node> m_node;
OP_TYPEID m_typeid;
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment