Unverified commit d81d0c93 authored by Robert Kimball, committed by GitHub

Interpreter use switch() for main loop (#1538)

* wip

* interpreter use switch instead of if/else

* more cleanup

* make nop elimination run on all backends

* revert

* use single include file to define all ops so there is only one instance

* move op.tbl to ngraph/op dir as it is useful. Added usage example.

* add some comments where needed

* revert some changes to reduce delta

* add const

* add more const

* simplify using NodeWrapper

* update per review comments

* update per review comments

* update per review comments

* remove switch warning as it is not supported in older gcc
parent 5032f343
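The "single include file" bullet above refers to the X-macro pattern: op_tbl.hpp (below) is included multiple times under different definitions of NGRAPH_OP. A minimal, self-contained sketch of the idea, using a hypothetical three-op table inlined as a macro rather than an include file:

// Hypothetical three-entry stand-in for op_tbl.hpp.
#define OP_TABLE(NGRAPH_OP) NGRAPH_OP(Abs) NGRAPH_OP(Acos) NGRAPH_OP(Add)

// Expansion 1: an enum with one entry per op.
#define AS_ENUM(a) a,
enum class OP_TYPEID { OP_TABLE(AS_ENUM) };
#undef AS_ENUM

// Expansion 2: matching name strings that cannot drift out of sync with the enum.
#define AS_NAME(a) #a,
const char* op_names[] = { OP_TABLE(AS_NAME) };
#undef AS_NAME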
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// This collection contains one entry for each op. If an op is added, it must be
// added to this list.
//
// To use this list, define a macro named exactly NGRAPH_OP before including this
// file, and #undef the macro when you are done.
// As an example, if you wanted to make a list of all op names as strings you could do this:
//
// #define NGRAPH_OP(a) #a,
// std::vector<std::string> op_names{
// #include "this include file name"
// };
// #undef NGRAPH_OP
//
// This sample expands to a list like this:
// "Abs",
// "Acos",
// ...
//
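// Another expansion, used by the interpreter's node_wrapper.hpp in this commit,
// turns the same list into an enum:
//
// #define NGRAPH_OP(a) a,
// enum class OP_TYPEID
// {
// #include "ngraph/op/op_tbl.hpp"
// };
// #undef NGRAPH_OP
//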
// It's that easy. You can use this for fun and profit.
NGRAPH_OP(Abs)
NGRAPH_OP(Acos)
NGRAPH_OP(Add)
NGRAPH_OP(AllReduce)
NGRAPH_OP(And)
NGRAPH_OP(ArgMax)
NGRAPH_OP(ArgMin)
NGRAPH_OP(Asin)
NGRAPH_OP(Atan)
NGRAPH_OP(AvgPool)
NGRAPH_OP(AvgPoolBackprop)
NGRAPH_OP(BatchNorm)
NGRAPH_OP(BatchNormBackprop)
NGRAPH_OP(Broadcast)
NGRAPH_OP(Ceiling)
NGRAPH_OP(Concat)
NGRAPH_OP(Constant)
NGRAPH_OP(Convert)
NGRAPH_OP(Convolution)
NGRAPH_OP(ConvolutionBackpropData)
NGRAPH_OP(ConvolutionBackpropFilters)
NGRAPH_OP(Cos)
NGRAPH_OP(Cosh)
NGRAPH_OP(Divide)
NGRAPH_OP(Dot)
NGRAPH_OP(Equal)
NGRAPH_OP(Exp)
NGRAPH_OP(Floor)
NGRAPH_OP(FunctionCall)
NGRAPH_OP(GetOutputElement)
NGRAPH_OP(Greater)
NGRAPH_OP(GreaterEq)
NGRAPH_OP(Less)
NGRAPH_OP(LessEq)
NGRAPH_OP(Log)
NGRAPH_OP(LRN)
NGRAPH_OP(Max)
NGRAPH_OP(Maximum)
NGRAPH_OP(MaxPool)
NGRAPH_OP(MaxPoolBackprop)
NGRAPH_OP(Min)
NGRAPH_OP(Minimum)
NGRAPH_OP(Multiply)
NGRAPH_OP(Negative)
NGRAPH_OP(Not)
NGRAPH_OP(NotEqual)
NGRAPH_OP(OneHot)
NGRAPH_OP(Or)
NGRAPH_OP(Pad)
NGRAPH_OP(Parameter)
NGRAPH_OP(Power)
NGRAPH_OP(Product)
NGRAPH_OP(Reduce)
NGRAPH_OP(ReduceWindow)
NGRAPH_OP(Relu)
NGRAPH_OP(ReluBackprop)
NGRAPH_OP(ReplaceSlice)
NGRAPH_OP(Reshape)
NGRAPH_OP(Result)
NGRAPH_OP(Reverse)
NGRAPH_OP(ReverseSequence)
NGRAPH_OP(Select)
NGRAPH_OP(SelectAndScatter)
NGRAPH_OP(Sigmoid)
NGRAPH_OP(SigmoidBackprop)
NGRAPH_OP(Sign)
NGRAPH_OP(Sin)
NGRAPH_OP(Sinh)
NGRAPH_OP(Slice)
NGRAPH_OP(Softmax)
NGRAPH_OP(Sqrt)
NGRAPH_OP(StopGradient)
NGRAPH_OP(Subtract)
NGRAPH_OP(Sum)
NGRAPH_OP(Tan)
NGRAPH_OP(Tanh)
NGRAPH_OP(TopK)
@@ -175,28 +175,7 @@ map<string, string> runtime::BackendManager::get_registered_device_map()
string backend_name;
if (is_backend_name(name, backend_name))
{
DL_HANDLE handle;
#ifdef WIN32
handle = LoadLibrary(file.c_str());
#else
handle = dlopen(file.c_str(), RTLD_LAZY | RTLD_LOCAL);
#endif
if (handle)
{
if (DLSYM(handle, "new_backend") && DLSYM(handle, "delete_backend"))
{
function<const char*()> get_ngraph_version_string =
reinterpret_cast<const char* (*)()>(
DLSYM(handle, "get_ngraph_version_string"));
if (get_ngraph_version_string &&
get_ngraph_version_string() == string(NGRAPH_VERSION))
{
rc.insert({to_upper(backend_name), file});
}
}
CLOSE_LIBRARY(handle);
}
rc.insert({to_upper(backend_name), file});
}
};
file_util::iterate_files(my_directory, f, false, true);
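For context on the symbols the removed version check probed: each backend shared library is expected to export a small C interface. A minimal sketch of those exports follows; only the symbol names and the signature of get_ngraph_version_string appear in the code above, so the other two signatures are assumptions:

#include "ngraph/runtime/interpreter/int_backend.hpp"

// Signature matches the reinterpret_cast in the removed code; NGRAPH_VERSION is
// assumed to be the version-string macro compared against above.
extern "C" const char* get_ngraph_version_string()
{
    return NGRAPH_VERSION; // must equal the loading runtime's version
}

// Assumed signatures; the removed code only checks that these symbols exist.
extern "C" ngraph::runtime::Backend* new_backend(const char* /*config*/)
{
    return new ngraph::runtime::interpreter::INTBackend();
}

extern "C" void delete_backend(ngraph::runtime::Backend* backend)
{
    delete backend;
}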
@@ -23,8 +23,8 @@ if(NGRAPH_DISTRIBUTED_ENABLE)
endif()
if (NGRAPH_INTERPRETER_ENABLE)
add_library(interpreter_backend SHARED int_backend.cpp)
set_target_properties(interpreter_backend PROPERTIES VERSION ${NGRAPH_VERSION} SOVERSION ${NGRAPH_API_VERSION})
add_library(interpreter_backend SHARED int_backend.cpp node_wrapper.cpp)
set_target_properties(interpreter_backend PROPERTIES VERSION ${NGRAPH_VERSION})
target_link_libraries(interpreter_backend PUBLIC ngraph)
set_target_properties(interpreter_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR})
@@ -16,6 +16,7 @@
#include "ngraph/runtime/interpreter/int_backend.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
@@ -68,6 +69,11 @@ bool runtime::interpreter::INTBackend::compile(shared_ptr<Function> function)
pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.run_passes(function);
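// Cache a NodeWrapper (node plus its OP_TYPEID) for every op at compile time so
// that call() can switch() on the enum instead of string-comparing description().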
for (const shared_ptr<Node>& node : function->get_ordered_ops())
{
instance.m_wrapped_nodes.emplace_back(node);
}
}
return true;
@@ -125,13 +131,14 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
}
// for each ordered op in the graph
for (shared_ptr<Node> op : function->get_ordered_ops())
for (const NodeWrapper& wrapped : instance.m_wrapped_nodes)
{
const Node* op = &wrapped.get_node();
auto type_id = wrapped.get_typeid();
if (op->description() == "Parameter")
{
continue;
}
// get op inputs from map
vector<shared_ptr<runtime::HostTensorView>> op_inputs;
for (const descriptor::Input& input : op->get_inputs())
@@ -164,35 +171,37 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
// get op type
element::Type type;
if (dynamic_pointer_cast<op::util::BinaryElementwiseComparison>(op) ||
dynamic_pointer_cast<op::Select>(op))
switch (type_id)
{
case OP_TYPEID::Convert:
type = op->get_inputs().at(0).get_tensor().get_element_type();
break;
case OP_TYPEID::Equal:
case OP_TYPEID::Greater:
case OP_TYPEID::GreaterEq:
case OP_TYPEID::Less:
case OP_TYPEID::LessEq:
case OP_TYPEID::NotEqual:
// Get the type of the second input, not the first
// All BinaryElementwiseComparison ops have the same type for inputs
// Select has bool for first input and the type we are interested in for the second
type = op->get_inputs().at(1).get_tensor().get_element_type();
}
else if (dynamic_pointer_cast<op::Convert>(op))
{
type = op->get_inputs().at(0).get_tensor().get_element_type();
}
else
{
type = op->get_outputs().at(0).get_element_type();
break;
default: type = op->get_outputs().at(0).get_element_type(); break;
}
if (instance.m_performance_counters_enabled)
{
instance.m_timer_map[op.get()].start();
instance.m_timer_map[op].start();
}
generate_calls(type, *op, op_outputs, op_inputs);
generate_calls(type, wrapped, op_outputs, op_inputs);
if (instance.m_performance_counters_enabled)
{
instance.m_timer_map[op.get()].stop();
instance.m_timer_map[op].stop();
}
if (instance.m_nan_check_enabled)
{
perform_nan_check(op_outputs, op.get());
perform_nan_check(op_outputs, op);
}
// delete any obsolete tensors
@@ -214,7 +223,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
void runtime::interpreter::INTBackend::generate_calls(
const element::Type& type,
Node& op,
const NodeWrapper& op,
const vector<shared_ptr<HostTensorView>>& outputs,
const vector<shared_ptr<HostTensorView>>& inputs)
{
@@ -265,7 +274,7 @@ void runtime::interpreter::INTBackend::generate_calls(
else
{
stringstream ss;
ss << "unsupported element type " << type << " op " << op.get_name();
ss << "unsupported element type " << type << " op " << op.get_node().get_name();
throw ngraph_error(ss.str());
}
}
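For context, the elided body of generate_calls dispatches on the element type to a matching op_engine<T> instantiation. A condensed sketch, assuming the elided branches follow the pattern of the else branch shown above:

// Condensed sketch of the elided dispatch; one branch per element type is assumed.
if (type == element::f32)
{
    op_engine<float>(op, outputs, inputs);
}
else if (type == element::f64)
{
    op_engine<double>(op, outputs, inputs);
}
else if (type == element::i32)
{
    op_engine<int32_t>(op, outputs, inputs);
}
// ... remaining element types ...
else
{
    stringstream ss;
    ss << "unsupported element type " << type << " op " << op.get_node().get_name();
    throw ngraph_error(ss.str());
}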
@@ -21,10 +21,6 @@
#include <string>
#include <vector>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
@@ -49,12 +45,15 @@
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
@@ -121,6 +120,7 @@
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/runtime/reference/allreduce.hpp"
@@ -136,6 +136,7 @@ namespace ngraph
}
}
}
class ngraph::runtime::interpreter::INTBackend : public Backend
{
public:
@@ -165,6 +166,7 @@ private:
bool m_nan_check_enabled = false;
bool m_performance_counters_enabled = false;
std::unordered_map<const Node*, stopwatch> m_timer_map;
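// Filled in by compile(): one entry per op, in execution order, so that call()
// does not have to re-derive each node's type id on every invocation.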
std::vector<NodeWrapper> m_wrapped_nodes;
};
std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
@@ -172,50 +174,65 @@ private:
const Node* op = nullptr);
void generate_calls(const element::Type& type,
Node& op,
const NodeWrapper& op,
const std::vector<std::shared_ptr<HostTensorView>>& outputs,
const std::vector<std::shared_ptr<HostTensorView>>& inputs);
template <typename T>
void op_engine(Node& node,
void op_engine(const NodeWrapper& node_wrapper,
const std::vector<std::shared_ptr<HostTensorView>>& out,
const std::vector<std::shared_ptr<HostTensorView>>& args)
{
const Node& node = node_wrapper.get_node();
std::string node_op = node.description();
if (node_op == "Abs")
// We want to check that every OP_TYPEID enumerator is handled in this switch.
// These GCC diagnostics enable compile-time checking so that if an enumerator
// is missing a case, an error is generated.
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wcovered-switch-default"
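// As a hypothetical illustration of what these diagnostics catch: given
//     enum class E { A, B };
// compiling
//     switch (e) { case E::A: break; }
// with -Werror=switch-enum fails with "enumeration value 'B' not handled in
// switch", which is exactly how a newly added op with no case below is caught.
// (The covered-switch-default diagnostic is left commented out because older
// GCC versions do not support it.)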
switch (node_wrapper.get_typeid())
{
case OP_TYPEID::Abs:
{
reference::abs<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Acos")
case OP_TYPEID::Acos:
{
reference::acos<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Add")
case OP_TYPEID::Add:
{
reference::add<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
case OP_TYPEID::AllReduce: {
#ifdef NGRAPH_DISTRIBUTED
else if (node_op == "AllReduce")
{
reference::allreduce<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_element_type(),
static_cast<int>(args[0]->get_element_count()));
}
#endif
else if (node_op == "And")
break;
}
case OP_TYPEID::And:
{
reference::logical_and(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
else if (node_op == "ArgMin")
case OP_TYPEID::ArgMin:
{
const op::ArgMin* argmin = static_cast<const op::ArgMin*>(&node);
if (out[0]->get_element_type() == element::i64)
@@ -238,8 +255,9 @@ private:
{
throw ngraph_error("Unexpected type");
}
break;
}
else if (node_op == "ArgMax")
case OP_TYPEID::ArgMax:
{
const op::ArgMax* argmax = static_cast<const op::ArgMax*>(&node);
if (out[0]->get_element_type() == element::i64)
@@ -262,20 +280,23 @@ private:
{
throw ngraph_error("Unexpected type");
}
break;
}
else if (node_op == "Asin")
case OP_TYPEID::Asin:
{
reference::asin<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Atan")
case OP_TYPEID::Atan:
{
reference::atan<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "AvgPool")
case OP_TYPEID::AvgPool:
{
op::AvgPool* avg_pool = dynamic_cast<op::AvgPool*>(&node);
const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
reference::avg_pool<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
@@ -286,18 +307,20 @@ private:
avg_pool->get_padding_below(),
avg_pool->get_padding_above(),
avg_pool->get_include_padding_in_avg_computation());
break;
}
else if (node_op == "GetOutputElement")
case OP_TYPEID::GetOutputElement:
{
const op::GetOutputElement* get_output_element =
static_cast<const op::GetOutputElement*>(&node);
size_t n = get_output_element->get_n();
size_t num_bytes = out[0]->get_element_count() * out[0]->get_element_type().size();
std::memcpy(out[0]->get_data_ptr(), args[n]->get_data_ptr(), num_bytes);
break;
}
else if (node_op == "BatchNorm")
case OP_TYPEID::BatchNorm:
{
ngraph::op::BatchNorm* bn = dynamic_cast<ngraph::op::BatchNorm*>(&node);
const ngraph::op::BatchNorm* bn = static_cast<const ngraph::op::BatchNorm*>(&node);
if (bn->get_output_size() == 3)
{
reference::batch_norm_three_outputs<T>(
@@ -321,11 +344,12 @@ private:
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[2]->get_shape());
}
break;
}
else if (node_op == "BatchNormBackprop")
case OP_TYPEID::BatchNormBackprop:
{
ngraph::op::BatchNormBackprop* bn_bprop =
dynamic_cast<ngraph::op::BatchNormBackprop*>(&node);
const ngraph::op::BatchNormBackprop* bn_bprop =
static_cast<const ngraph::op::BatchNormBackprop*>(&node);
reference::batch_norm_backprop(bn_bprop->get_eps_value(),
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
@@ -337,10 +361,11 @@ private:
reinterpret_cast<T*>(out[1]->get_data_ptr()),
reinterpret_cast<T*>(out[2]->get_data_ptr()),
args[2]->get_shape());
break;
}
else if (node_op == "AvgPoolBackprop")
case OP_TYPEID::AvgPoolBackprop:
{
op::AvgPoolBackprop* apb = dynamic_cast<op::AvgPoolBackprop*>(&node);
const op::AvgPoolBackprop* apb = static_cast<const op::AvgPoolBackprop*>(&node);
reference::avg_pool_backprop<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
@@ -350,10 +375,11 @@ private:
apb->get_padding_below(),
apb->get_padding_above(),
apb->get_include_padding_in_avg_computation());
break;
}
else if (node_op == "Broadcast")
case OP_TYPEID::Broadcast:
{
op::Broadcast* broadcast = dynamic_cast<op::Broadcast*>(&node);
const op::Broadcast* broadcast = static_cast<const op::Broadcast*>(&node);
Shape in_shape = args[0]->get_shape();
Shape out_shape = out[0]->get_shape();
AxisSet broadcast_axes = broadcast->get_broadcast_axes();
@@ -362,13 +388,15 @@ private:
in_shape,
out_shape,
broadcast_axes);
break;
}
else if (node_op == "Ceiling")
case OP_TYPEID::Ceiling:
{
reference::ceiling<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Concat")
case OP_TYPEID::Concat:
{
const op::Concat* concat = static_cast<const op::Concat*>(&node);
std::vector<const T*> in_args;
@@ -383,14 +411,16 @@ private:
in_shapes,
out[0]->get_shape(),
concat->get_concatenation_axis());
break;
}
else if (node_op == "Constant")
case OP_TYPEID::Constant:
{
const op::Constant* c = static_cast<const op::Constant*>(&node);
reference::constant<T>(
c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Convert")
case OP_TYPEID::Convert:
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
element::Type type = node.get_element_type();
@@ -466,10 +496,11 @@ private:
ss << "unsupported element type " << type << " op Convert";
throw std::runtime_error(ss.str());
}
break;
}
else if (node_op == "Convolution")
case OP_TYPEID::Convolution:
{
auto c = static_cast<const op::Convolution*>(&node);
const op::Convolution* c = static_cast<const op::Convolution*>(&node);
reference::convolution<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
@@ -488,10 +519,12 @@ private:
0,
1,
false);
break;
}
else if (node_op == "ConvolutionBackpropFilters")
case OP_TYPEID::ConvolutionBackpropFilters:
{
auto c = static_cast<const op::ConvolutionBackpropFilters*>(&node);
const op::ConvolutionBackpropFilters* c =
static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
@@ -510,11 +543,13 @@ private:
1,
0,
false);
break;
}
else if (node_op == "ConvolutionBackpropData")
case OP_TYPEID::ConvolutionBackpropData:
{
// Note that args[1] and args[0] are switched here from the usual order.
auto c = static_cast<const op::ConvolutionBackpropData*>(&node);
const op::ConvolutionBackpropData* c =
static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(args[1]->get_data_ptr<T>(),
args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
@@ -533,27 +568,31 @@ private:
0,
1,
true);
break;
}
else if (node_op == "Cos")
case OP_TYPEID::Cos:
{
reference::cos<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Cosh")
case OP_TYPEID::Cosh:
{
reference::cosh<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Divide")
case OP_TYPEID::Divide:
{
reference::divide<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
else if (node_op == "Dot")
case OP_TYPEID::Dot:
{
op::Dot* dot = dynamic_cast<op::Dot*>(&node);
const op::Dot* dot = static_cast<const op::Dot*>(&node);
reference::dot(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
@@ -562,26 +601,29 @@ private:
args[1]->get_shape(),
out[0]->get_shape(),
dot->get_reduction_axes_count());
break;
}
else if (node_op == "Equal")
case OP_TYPEID::Equal:
{
reference::equal<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
break;
}
else if (node_op == "Exp")
case OP_TYPEID::Exp:
{
reference::exp<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Floor")
case OP_TYPEID::Floor:
{
reference::floor<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "FunctionCall")
case OP_TYPEID::FunctionCall:
{
std::shared_ptr<Function> function = node.get_functions()[0];
@@ -598,41 +640,47 @@ private:
}
call(function, outputs, inputs);
break;
}
else if (node_op == "Greater")
case OP_TYPEID::Greater:
{
reference::greater<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
break;
}
else if (node_op == "GreaterEq")
case OP_TYPEID::GreaterEq:
{
reference::greater_eq<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
break;
}
else if (node_op == "Less")
case OP_TYPEID::Less:
{
reference::less<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
break;
}
else if (node_op == "LessEq")
case OP_TYPEID::LessEq:
{
reference::less_eq<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
break;
}
else if (node_op == "Log")
case OP_TYPEID::Log:
{
reference::log<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "LRN")
case OP_TYPEID::LRN:
{
const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(args[0]->get_data_ptr<T>(),
@@ -642,8 +690,9 @@ private:
lrn->get_beta(),
lrn->get_bias(),
lrn->get_nsize());
break;
}
else if (node_op == "Max")
case OP_TYPEID::Max:
{
const op::Max* max = static_cast<const op::Max*>(&node);
reference::max<T>(args[0]->get_data_ptr<T>(),
@@ -651,17 +700,19 @@ private:
args[0]->get_shape(),
out[0]->get_shape(),
max->get_reduction_axes());
break;
}
else if (node_op == "Maximum")
case OP_TYPEID::Maximum:
{
reference::maximum<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
else if (node_op == "MaxPool")
case OP_TYPEID::MaxPool:
{
op::MaxPool* max_pool = dynamic_cast<op::MaxPool*>(&node);
const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node);
reference::max_pool<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
@@ -671,10 +722,12 @@ private:
max_pool->get_window_movement_strides(),
max_pool->get_padding_below(),
max_pool->get_padding_above());
break;
}
else if (node_op == "MaxPoolBackprop")
case OP_TYPEID::MaxPoolBackprop:
{
op::MaxPoolBackprop* max_pool_backprop = dynamic_cast<op::MaxPoolBackprop*>(&node);
const op::MaxPoolBackprop* max_pool_backprop =
static_cast<const op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
@@ -685,8 +738,9 @@ private:
max_pool_backprop->get_window_movement_strides(),
max_pool_backprop->get_padding_below(),
max_pool_backprop->get_padding_above());
break;
}
else if (node_op == "Min")
case OP_TYPEID::Min:
{
const op::Min* min = static_cast<const op::Min*>(&node);
reference::min<T>(args[0]->get_data_ptr<T>(),
@@ -694,60 +748,66 @@ private:
args[0]->get_shape(),
out[0]->get_shape(),
min->get_reduction_axes());
break;
}
else if (node_op == "Minimum")
case OP_TYPEID::Minimum:
{
reference::minimum<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
else if (node_op == "Multiply")
case OP_TYPEID::Multiply:
{
reference::multiply<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
else if (node_op == "Negative")
case OP_TYPEID::Negative:
{
reference::negate<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Not")
case OP_TYPEID::Not:
{
reference::logical_not(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "NotEqual")
case OP_TYPEID::NotEqual:
{
reference::not_equal<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
break;
}
else if (node_op == "OneHot")
case OP_TYPEID::OneHot:
{
auto oh = static_cast<const op::OneHot*>(&node);
const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
oh->get_one_hot_axis());
break;
}
else if (node_op == "Or")
case OP_TYPEID::Or:
{
reference::logical_or(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
else if (node_op == "Parameter")
case OP_TYPEID::Parameter: break;
case OP_TYPEID::Pad:
{
}
else if (node_op == "Pad")
{
op::Pad* pad = dynamic_cast<op::Pad*>(&node);
const op::Pad* pad = static_cast<const op::Pad*>(&node);
reference::pad(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
@@ -757,15 +817,17 @@ private:
pad->get_padding_below(),
pad->get_padding_above(),
pad->get_padding_interior());
break;
}
else if (node_op == "Power")
case OP_TYPEID::Power:
{
reference::power<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
else if (node_op == "Product")
case OP_TYPEID::Product:
{
const op::Product* product = static_cast<const op::Product*>(&node);
reference::product<T>(args[0]->get_data_ptr<T>(),
@@ -773,10 +835,11 @@ private:
args[0]->get_shape(),
out[0]->get_shape(),
product->get_reduction_axes());
break;
}
else if (node_op == "Reduce")
case OP_TYPEID::Reduce:
{
op::Reduce* reduce = dynamic_cast<op::Reduce*>(&node);
const op::Reduce* reduce = static_cast<const op::Reduce*>(&node);
std::shared_ptr<Function> reduction_function = reduce->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
@@ -799,10 +862,11 @@ private:
node.get_output_shape(0),
reduce->get_reduction_axes(),
f);
break;
}
else if (node_op == "ReduceWindow")
case OP_TYPEID::ReduceWindow:
{
op::ReduceWindow* reduce_window = dynamic_cast<op::ReduceWindow*>(&node);
const op::ReduceWindow* reduce_window = static_cast<const op::ReduceWindow*>(&node);
std::shared_ptr<Function> reduction_function = reduce_window->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
@@ -826,20 +890,23 @@ private:
f,
reduce_window->get_window_shape(),
reduce_window->get_window_movement_strides());
break;
}
else if (node_op == "Relu")
case OP_TYPEID::Relu:
{
reference::relu<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "ReluBackprop")
case OP_TYPEID::ReluBackprop:
{
reference::relu_backprop<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
else if (node_op == "ReplaceSlice")
case OP_TYPEID::ReplaceSlice:
{
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(args[0]->get_data_ptr<T>(),
@@ -850,35 +917,39 @@ private:
slice->get_upper_bounds(),
slice->get_strides(),
out[0]->get_shape());
break;
}
else if (node_op == "Reshape")
case OP_TYPEID::Reshape:
{
op::Reshape* reshape = dynamic_cast<op::Reshape*>(&node);
const op::Reshape* reshape = static_cast<const op::Reshape*>(&node);
reference::reshape(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
reshape->get_input_order(),
out[0]->get_shape());
break;
}
else if (node_op == "Result")
case OP_TYPEID::Result:
{
op::Result* res = dynamic_cast<op::Result*>(&node);
const op::Result* res = static_cast<const op::Result*>(&node);
reference::result(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
shape_size(res->get_shape()));
break;
}
else if (node_op == "Reverse")
case OP_TYPEID::Reverse:
{
op::Reverse* reverse = dynamic_cast<op::Reverse*>(&node);
const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
reference::reverse(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reverse->get_reversed_axes());
break;
}
else if (node_op == "ReverseSequence")
case OP_TYPEID::ReverseSequence:
{
op::ReverseSequence* reverse = dynamic_cast<op::ReverseSequence*>(&node);
const op::ReverseSequence* reverse = static_cast<const op::ReverseSequence*>(&node);
if (args[1]->get_element_type() == element::i32)
{
@@ -893,19 +964,21 @@ private:
{
throw ngraph_error("only int32 indices are supported");
}
break;
}
else if (node_op == "Select")
case OP_TYPEID::Select:
{
reference::select<T>(args[0]->get_data_ptr<char>(),
args[1]->get_data_ptr<T>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
else if (node_op == "SelectAndScatter")
case OP_TYPEID::SelectAndScatter:
{
ngraph::op::SelectAndScatter* select_and_scatter =
dynamic_cast<ngraph::op::SelectAndScatter*>(&node);
const ngraph::op::SelectAndScatter* select_and_scatter =
static_cast<const ngraph::op::SelectAndScatter*>(&node);
std::shared_ptr<ngraph::Function> selection_function =
select_and_scatter->get_functions()[0];
@@ -949,35 +1022,41 @@ private:
f_scatter,
select_and_scatter->get_window_shape(),
select_and_scatter->get_window_movement_strides());
break;
}
else if (node_op == "Sigmoid")
case OP_TYPEID::Sigmoid:
{
reference::sigmoid<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "SigmoidBackprop")
case OP_TYPEID::SigmoidBackprop:
{
reference::sigmoid_backprop<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
else if (node_op == "Sign")
case OP_TYPEID::Sign:
{
reference::sign<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Sin")
case OP_TYPEID::Sin:
{
reference::sin<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Sinh")
case OP_TYPEID::Sinh:
{
reference::sinh<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Slice")
case OP_TYPEID::Slice:
{
const op::Slice* slice = static_cast<const op::Slice*>(&node);
reference::slice<T>(args[0]->get_data_ptr<T>(),
@@ -987,28 +1066,37 @@ private:
slice->get_upper_bounds(),
slice->get_strides(),
out[0]->get_shape());
break;
}
else if (node_op == "Softmax")
case OP_TYPEID::Softmax:
{
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node);
reference::softmax<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_shape(),
softmax->get_axes());
break;
}
else if (node_op == "Sqrt")
case OP_TYPEID::Sqrt:
{
reference::sqrt<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Subtract")
case OP_TYPEID::StopGradient:
{
// TODO: Throw a real unsupported_op when available
throw std::runtime_error("Unsupported op 'StopGradient'");
}
case OP_TYPEID::Subtract:
{
reference::subtract<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
break;
}
else if (node_op == "Sum")
case OP_TYPEID::Sum:
{
const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(args[0]->get_data_ptr<T>(),
@@ -1016,18 +1104,21 @@ private:
args[0]->get_shape(),
out[0]->get_shape(),
sum->get_reduction_axes());
break;
}
else if (node_op == "Tan")
case OP_TYPEID::Tan:
{
reference::tan<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "Tanh")
case OP_TYPEID::Tanh:
{
reference::tanh<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
else if (node_op == "TopK")
case OP_TYPEID::TopK:
{
const op::TopK* topk = static_cast<const op::TopK*>(&node);
if (out[0]->get_element_type() == element::i64)
@@ -1057,11 +1148,7 @@ private:
throw ngraph_error("Unexpected type");
}
}
else
{
std::stringstream ss;
ss << "unsupported op " << node_op;
throw ngraph_error(ss.str());
#pragma GCC diagnostic pop
}
}
};
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
using namespace ngraph;
using namespace std;
runtime::interpreter::NodeWrapper::NodeWrapper(const shared_ptr<const Node>& node)
: m_node{node}
{
// This expands the op list in op_tbl.hpp into map initializer entries that look like this:
// {"Abs", runtime::interpreter::OP_TYPEID::Abs},
// {"Acos", runtime::interpreter::OP_TYPEID::Acos},
// ...
#define NGRAPH_OP(a) {#a, runtime::interpreter::OP_TYPEID::a},
static unordered_map<string, runtime::interpreter::OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
auto it = typeid_map.find(m_node->description());
if (it != typeid_map.end())
{
m_typeid = it->second;
}
else
{
// TODO: use unsupported_op when that is merged to master
throw runtime_error(m_node->description());
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node.hpp"
namespace ngraph
{
namespace runtime
{
namespace interpreter
{
enum class OP_TYPEID;
class NodeWrapper;
}
}
}
// This expands the op list in op_tbl.hpp into a list of enumerators that looks like this:
// Abs,
// Acos,
// ...
#define NGRAPH_OP(a) a,
enum class ngraph::runtime::interpreter::OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
/// \brief This class allows adding an enum typeid to each Node. This makes dealing with
/// collections of Nodes a little easier and faster, as we can use switch() instead of
/// chains of if/else statements.
class ngraph::runtime::interpreter::NodeWrapper
{
public:
NodeWrapper(const std::shared_ptr<const ngraph::Node>& node);
const Node& get_node() const { return *m_node; }
ngraph::runtime::interpreter::OP_TYPEID get_typeid() const { return m_typeid; }
private:
std::shared_ptr<const ngraph::Node> m_node;
OP_TYPEID m_typeid;
};
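A minimal usage sketch (hypothetical snippet, not part of the commit): wrap a node once, then dispatch on the cached enum instead of comparing description() strings:

#include "ngraph/runtime/interpreter/node_wrapper.hpp"

using namespace ngraph::runtime::interpreter;

const char* classify(const NodeWrapper& wrapped)
{
    // The OP_TYPEID was looked up once in the NodeWrapper constructor; every
    // subsequent dispatch is an integer switch rather than a string compare.
    switch (wrapped.get_typeid())
    {
    case OP_TYPEID::Abs: return "unary";
    case OP_TYPEID::Add: return "binary";
    default: return "other";
    }
}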