Unverified Commit d81d0c93 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Interpreter use switch() for main loop (#1538)

* wip

* interperter use switch instead of if/else

* more cleanup

* make nop elimination run on all backends

* revert

* use single include file to define all ops so there is only one instance

* move op.tbl to ngraph/op dir as it is useful. Added useage example.

* add some comments where needed

* revert some changes to reduce delta

* add const

* add more const

* simplify using NodeWrapper

* update per review comments

* update per review comments

* update per review comments

* remove switch warning as it is not supported in older gcc
parent 5032f343
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// This collection contains one entry for each op. If an op is added it must be
// added to this list.
//
// In order to use this list you want to define a macro named exactly NGRAPH_OP
// When you are done you should undef the macro
// As an example if you wanted to make a list of all op names as strings you could do this:
//
// #define NGRAPH_OP(a) #a,
// std::vector<std::string> op_names{
// #include "this include file name"
// };
// #undef NGRAPH_OP
//
// This sample expands to a list like this:
// "Abs",
// "Acos",
// ...
//
// It's that easy. You can use this for fun and profit.
NGRAPH_OP(Abs)
NGRAPH_OP(Acos)
NGRAPH_OP(Add)
NGRAPH_OP(AllReduce)
NGRAPH_OP(And)
NGRAPH_OP(ArgMax)
NGRAPH_OP(ArgMin)
NGRAPH_OP(Asin)
NGRAPH_OP(Atan)
NGRAPH_OP(AvgPool)
NGRAPH_OP(AvgPoolBackprop)
NGRAPH_OP(BatchNorm)
NGRAPH_OP(BatchNormBackprop)
NGRAPH_OP(Broadcast)
NGRAPH_OP(Ceiling)
NGRAPH_OP(Concat)
NGRAPH_OP(Constant)
NGRAPH_OP(Convert)
NGRAPH_OP(Convolution)
NGRAPH_OP(ConvolutionBackpropData)
NGRAPH_OP(ConvolutionBackpropFilters)
NGRAPH_OP(Cos)
NGRAPH_OP(Cosh)
NGRAPH_OP(Divide)
NGRAPH_OP(Dot)
NGRAPH_OP(Equal)
NGRAPH_OP(Exp)
NGRAPH_OP(Floor)
NGRAPH_OP(FunctionCall)
NGRAPH_OP(GetOutputElement)
NGRAPH_OP(Greater)
NGRAPH_OP(GreaterEq)
NGRAPH_OP(Less)
NGRAPH_OP(LessEq)
NGRAPH_OP(Log)
NGRAPH_OP(LRN)
NGRAPH_OP(Max)
NGRAPH_OP(Maximum)
NGRAPH_OP(MaxPool)
NGRAPH_OP(MaxPoolBackprop)
NGRAPH_OP(Min)
NGRAPH_OP(Minimum)
NGRAPH_OP(Multiply)
NGRAPH_OP(Negative)
NGRAPH_OP(Not)
NGRAPH_OP(NotEqual)
NGRAPH_OP(OneHot)
NGRAPH_OP(Or)
NGRAPH_OP(Pad)
NGRAPH_OP(Parameter)
NGRAPH_OP(Power)
NGRAPH_OP(Product)
NGRAPH_OP(Reduce)
NGRAPH_OP(ReduceWindow)
NGRAPH_OP(Relu)
NGRAPH_OP(ReluBackprop)
NGRAPH_OP(ReplaceSlice)
NGRAPH_OP(Reshape)
NGRAPH_OP(Result)
NGRAPH_OP(Reverse)
NGRAPH_OP(ReverseSequence)
NGRAPH_OP(Select)
NGRAPH_OP(SelectAndScatter)
NGRAPH_OP(Sigmoid)
NGRAPH_OP(SigmoidBackprop)
NGRAPH_OP(Sign)
NGRAPH_OP(Sin)
NGRAPH_OP(Sinh)
NGRAPH_OP(Slice)
NGRAPH_OP(Softmax)
NGRAPH_OP(Sqrt)
NGRAPH_OP(StopGradient)
NGRAPH_OP(Subtract)
NGRAPH_OP(Sum)
NGRAPH_OP(Tan)
NGRAPH_OP(Tanh)
NGRAPH_OP(TopK)
...@@ -174,30 +174,9 @@ map<string, string> runtime::BackendManager::get_registered_device_map() ...@@ -174,30 +174,9 @@ map<string, string> runtime::BackendManager::get_registered_device_map()
string name = file_util::get_file_name(file); string name = file_util::get_file_name(file);
string backend_name; string backend_name;
if (is_backend_name(name, backend_name)) if (is_backend_name(name, backend_name))
{
DL_HANDLE handle;
#ifdef WIN32
handle = LoadLibrary(file.c_str());
#else
handle = dlopen(file.c_str(), RTLD_LAZY | RTLD_LOCAL);
#endif
if (handle)
{
if (DLSYM(handle, "new_backend") && DLSYM(handle, "delete_backend"))
{
function<const char*()> get_ngraph_version_string =
reinterpret_cast<const char* (*)()>(
DLSYM(handle, "get_ngraph_version_string"));
if (get_ngraph_version_string &&
get_ngraph_version_string() == string(NGRAPH_VERSION))
{ {
rc.insert({to_upper(backend_name), file}); rc.insert({to_upper(backend_name), file});
} }
}
CLOSE_LIBRARY(handle);
}
}
}; };
file_util::iterate_files(my_directory, f, false, true); file_util::iterate_files(my_directory, f, false, true);
return rc; return rc;
......
...@@ -23,8 +23,8 @@ if(NGRAPH_DISTRIBUTED_ENABLE) ...@@ -23,8 +23,8 @@ if(NGRAPH_DISTRIBUTED_ENABLE)
endif() endif()
if (NGRAPH_INTERPRETER_ENABLE) if (NGRAPH_INTERPRETER_ENABLE)
add_library(interpreter_backend SHARED int_backend.cpp) add_library(interpreter_backend SHARED int_backend.cpp node_wrapper.cpp)
set_target_properties(interpreter_backend PROPERTIES VERSION ${NGRAPH_VERSION} SOVERSION ${NGRAPH_API_VERSION}) set_target_properties(interpreter_backend PROPERTIES VERSION ${NGRAPH_VERSION})
target_link_libraries(interpreter_backend PUBLIC ngraph) target_link_libraries(interpreter_backend PUBLIC ngraph)
set_target_properties(interpreter_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR}) set_target_properties(interpreter_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR})
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/runtime/interpreter/int_backend.hpp" #include "ngraph/runtime/interpreter/int_backend.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp" #include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/convert.hpp" #include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp" #include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp" #include "ngraph/op/util/binary_elementwise_comparison.hpp"
...@@ -68,6 +69,11 @@ bool runtime::interpreter::INTBackend::compile(shared_ptr<Function> function) ...@@ -68,6 +69,11 @@ bool runtime::interpreter::INTBackend::compile(shared_ptr<Function> function)
pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>(); pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
pass_manager.run_passes(function); pass_manager.run_passes(function);
for (const shared_ptr<Node>& node : function->get_ordered_ops())
{
instance.m_wrapped_nodes.emplace_back(node);
}
} }
return true; return true;
...@@ -125,13 +131,14 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -125,13 +131,14 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
} }
// for each ordered op in the graph // for each ordered op in the graph
for (shared_ptr<Node> op : function->get_ordered_ops()) for (const NodeWrapper& wrapped : instance.m_wrapped_nodes)
{ {
const Node* op = &wrapped.get_node();
auto type_id = wrapped.get_typeid();
if (op->description() == "Parameter") if (op->description() == "Parameter")
{ {
continue; continue;
} }
// get op inputs from map // get op inputs from map
vector<shared_ptr<runtime::HostTensorView>> op_inputs; vector<shared_ptr<runtime::HostTensorView>> op_inputs;
for (const descriptor::Input& input : op->get_inputs()) for (const descriptor::Input& input : op->get_inputs())
...@@ -164,35 +171,37 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -164,35 +171,37 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
// get op type // get op type
element::Type type; element::Type type;
if (dynamic_pointer_cast<op::util::BinaryElementwiseComparison>(op) || switch (type_id)
dynamic_pointer_cast<op::Select>(op))
{ {
case OP_TYPEID::Convert:
type = op->get_inputs().at(0).get_tensor().get_element_type();
break;
case OP_TYPEID::Equal:
case OP_TYPEID::Greater:
case OP_TYPEID::GreaterEq:
case OP_TYPEID::Less:
case OP_TYPEID::LessEq:
case OP_TYPEID::NotEqual:
// Get the type of the second input, not the first // Get the type of the second input, not the first
// All BinaryElementwiseComparision ops have the same type for inputs // All BinaryElementwiseComparision ops have the same type for inputs
// Select has bool for first input and the type we are interested in for the second // Select has bool for first input and the type we are interested in for the second
type = op->get_inputs().at(1).get_tensor().get_element_type(); type = op->get_inputs().at(1).get_tensor().get_element_type();
} break;
else if (dynamic_pointer_cast<op::Convert>(op)) default: type = op->get_outputs().at(0).get_element_type(); break;
{
type = op->get_inputs().at(0).get_tensor().get_element_type();
}
else
{
type = op->get_outputs().at(0).get_element_type();
} }
if (instance.m_performance_counters_enabled) if (instance.m_performance_counters_enabled)
{ {
instance.m_timer_map[op.get()].start(); instance.m_timer_map[op].start();
} }
generate_calls(type, *op, op_outputs, op_inputs); generate_calls(type, wrapped, op_outputs, op_inputs);
if (instance.m_performance_counters_enabled) if (instance.m_performance_counters_enabled)
{ {
instance.m_timer_map[op.get()].stop(); instance.m_timer_map[op].stop();
} }
if (instance.m_nan_check_enabled) if (instance.m_nan_check_enabled)
{ {
perform_nan_check(op_outputs, op.get()); perform_nan_check(op_outputs, op);
} }
// delete any obsolete tensors // delete any obsolete tensors
...@@ -214,7 +223,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -214,7 +223,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
void runtime::interpreter::INTBackend::generate_calls( void runtime::interpreter::INTBackend::generate_calls(
const element::Type& type, const element::Type& type,
Node& op, const NodeWrapper& op,
const vector<shared_ptr<HostTensorView>>& outputs, const vector<shared_ptr<HostTensorView>>& outputs,
const vector<shared_ptr<HostTensorView>>& inputs) const vector<shared_ptr<HostTensorView>>& inputs)
{ {
...@@ -265,7 +274,7 @@ void runtime::interpreter::INTBackend::generate_calls( ...@@ -265,7 +274,7 @@ void runtime::interpreter::INTBackend::generate_calls(
else else
{ {
stringstream ss; stringstream ss;
ss << "unsupported element type " << type << " op " << op.get_name(); ss << "unsupported element type " << type << " op " << op.get_node().get_name();
throw ngraph_error(ss.str()); throw ngraph_error(ss.str());
} }
} }
......
...@@ -21,10 +21,6 @@ ...@@ -21,10 +21,6 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/op/argmax.hpp" #include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp" #include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
...@@ -49,12 +45,15 @@ ...@@ -49,12 +45,15 @@
#include "ngraph/op/result.hpp" #include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp" #include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/topk.hpp" #include "ngraph/op/topk.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
#include "ngraph/runtime/reference/abs.hpp" #include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp" #include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp" #include "ngraph/runtime/reference/add.hpp"
...@@ -121,6 +120,7 @@ ...@@ -121,6 +120,7 @@
#include "ngraph/runtime/reference/tan.hpp" #include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp" #include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp" #include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#ifdef NGRAPH_DISTRIBUTED #ifdef NGRAPH_DISTRIBUTED
#include "ngraph/runtime/reference/allreduce.hpp" #include "ngraph/runtime/reference/allreduce.hpp"
...@@ -136,6 +136,7 @@ namespace ngraph ...@@ -136,6 +136,7 @@ namespace ngraph
} }
} }
} }
class ngraph::runtime::interpreter::INTBackend : public Backend class ngraph::runtime::interpreter::INTBackend : public Backend
{ {
public: public:
...@@ -165,6 +166,7 @@ private: ...@@ -165,6 +166,7 @@ private:
bool m_nan_check_enabled = false; bool m_nan_check_enabled = false;
bool m_performance_counters_enabled = false; bool m_performance_counters_enabled = false;
std::unordered_map<const Node*, stopwatch> m_timer_map; std::unordered_map<const Node*, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes;
}; };
std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map; std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
...@@ -172,50 +174,65 @@ private: ...@@ -172,50 +174,65 @@ private:
const Node* op = nullptr); const Node* op = nullptr);
void generate_calls(const element::Type& type, void generate_calls(const element::Type& type,
Node& op, const NodeWrapper& op,
const std::vector<std::shared_ptr<HostTensorView>>& outputs, const std::vector<std::shared_ptr<HostTensorView>>& outputs,
const std::vector<std::shared_ptr<HostTensorView>>& inputs); const std::vector<std::shared_ptr<HostTensorView>>& inputs);
template <typename T> template <typename T>
void op_engine(Node& node, void op_engine(const NodeWrapper& node_wrapper,
const std::vector<std::shared_ptr<HostTensorView>>& out, const std::vector<std::shared_ptr<HostTensorView>>& out,
const std::vector<std::shared_ptr<HostTensorView>>& args) const std::vector<std::shared_ptr<HostTensorView>>& args)
{ {
const Node& node = node_wrapper.get_node();
std::string node_op = node.description(); std::string node_op = node.description();
if (node_op == "Abs")
// We want to check that every OP_TYPEID enumeration is included in the list.
// These GCC flags enable compile-time checking so that if an enumeration
// is not in the list an error is generated.
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wcovered-switch-default"
switch (node_wrapper.get_typeid())
{
case OP_TYPEID::Abs:
{ {
reference::abs<T>( reference::abs<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Acos") case OP_TYPEID::Acos:
{ {
reference::acos<T>( reference::acos<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Add") case OP_TYPEID::Add:
{ {
reference::add<T>(args[0]->get_data_ptr<T>(), reference::add<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
case OP_TYPEID::AllReduce: {
#ifdef NGRAPH_DISTRIBUTED #ifdef NGRAPH_DISTRIBUTED
else if (node_op == "AllReduce")
{
reference::allreduce<T>(args[0]->get_data_ptr<T>(), reference::allreduce<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
args[0]->get_element_type(), args[0]->get_element_type(),
static_cast<int>(args[0]->get_element_count())); static_cast<int>(args[0]->get_element_count()));
}
#endif #endif
else if (node_op == "And") break;
}
case OP_TYPEID::And:
{ {
reference::logical_and(args[0]->get_data_ptr<T>(), reference::logical_and(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "ArgMin") case OP_TYPEID::ArgMin:
{ {
const op::ArgMin* argmin = static_cast<const op::ArgMin*>(&node); const op::ArgMin* argmin = static_cast<const op::ArgMin*>(&node);
if (out[0]->get_element_type() == element::i64) if (out[0]->get_element_type() == element::i64)
...@@ -238,8 +255,9 @@ private: ...@@ -238,8 +255,9 @@ private:
{ {
throw ngraph_error("Unexpected type"); throw ngraph_error("Unexpected type");
} }
break;
} }
else if (node_op == "ArgMax") case OP_TYPEID::ArgMax:
{ {
const op::ArgMax* argmax = static_cast<const op::ArgMax*>(&node); const op::ArgMax* argmax = static_cast<const op::ArgMax*>(&node);
if (out[0]->get_element_type() == element::i64) if (out[0]->get_element_type() == element::i64)
...@@ -262,20 +280,23 @@ private: ...@@ -262,20 +280,23 @@ private:
{ {
throw ngraph_error("Unexpected type"); throw ngraph_error("Unexpected type");
} }
break;
} }
else if (node_op == "Asin") case OP_TYPEID::Asin:
{ {
reference::asin<T>( reference::asin<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Atan") case OP_TYPEID::Atan:
{ {
reference::atan<T>( reference::atan<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "AvgPool") case OP_TYPEID::AvgPool:
{ {
op::AvgPool* avg_pool = dynamic_cast<op::AvgPool*>(&node); const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
reference::avg_pool<T>(args[0]->get_data_ptr<T>(), reference::avg_pool<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
...@@ -286,18 +307,20 @@ private: ...@@ -286,18 +307,20 @@ private:
avg_pool->get_padding_below(), avg_pool->get_padding_below(),
avg_pool->get_padding_above(), avg_pool->get_padding_above(),
avg_pool->get_include_padding_in_avg_computation()); avg_pool->get_include_padding_in_avg_computation());
break;
} }
else if (node_op == "GetOutputElement") case OP_TYPEID::GetOutputElement:
{ {
const op::GetOutputElement* get_output_element = const op::GetOutputElement* get_output_element =
static_cast<const op::GetOutputElement*>(&node); static_cast<const op::GetOutputElement*>(&node);
size_t n = get_output_element->get_n(); size_t n = get_output_element->get_n();
size_t num_bytes = out[0]->get_element_count() * out[0]->get_element_type().size(); size_t num_bytes = out[0]->get_element_count() * out[0]->get_element_type().size();
std::memcpy(out[0]->get_data_ptr(), args[n]->get_data_ptr(), num_bytes); std::memcpy(out[0]->get_data_ptr(), args[n]->get_data_ptr(), num_bytes);
break;
} }
else if (node_op == "BatchNorm") case OP_TYPEID::BatchNorm:
{ {
ngraph::op::BatchNorm* bn = dynamic_cast<ngraph::op::BatchNorm*>(&node); const ngraph::op::BatchNorm* bn = static_cast<const ngraph::op::BatchNorm*>(&node);
if (bn->get_output_size() == 3) if (bn->get_output_size() == 3)
{ {
reference::batch_norm_three_outputs<T>( reference::batch_norm_three_outputs<T>(
...@@ -321,11 +344,12 @@ private: ...@@ -321,11 +344,12 @@ private:
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[2]->get_shape()); args[2]->get_shape());
} }
break;
} }
else if (node_op == "BatchNormBackprop") case OP_TYPEID::BatchNormBackprop:
{ {
ngraph::op::BatchNormBackprop* bn_bprop = const ngraph::op::BatchNormBackprop* bn_bprop =
dynamic_cast<ngraph::op::BatchNormBackprop*>(&node); static_cast<const ngraph::op::BatchNormBackprop*>(&node);
reference::batch_norm_backprop(bn_bprop->get_eps_value(), reference::batch_norm_backprop(bn_bprop->get_eps_value(),
reinterpret_cast<T*>(args[0]->get_data_ptr()), reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()), reinterpret_cast<T*>(args[1]->get_data_ptr()),
...@@ -337,10 +361,11 @@ private: ...@@ -337,10 +361,11 @@ private:
reinterpret_cast<T*>(out[1]->get_data_ptr()), reinterpret_cast<T*>(out[1]->get_data_ptr()),
reinterpret_cast<T*>(out[2]->get_data_ptr()), reinterpret_cast<T*>(out[2]->get_data_ptr()),
args[2]->get_shape()); args[2]->get_shape());
break;
} }
else if (node_op == "AvgPoolBackprop") case OP_TYPEID::AvgPoolBackprop:
{ {
op::AvgPoolBackprop* apb = dynamic_cast<op::AvgPoolBackprop*>(&node); const op::AvgPoolBackprop* apb = static_cast<const op::AvgPoolBackprop*>(&node);
reference::avg_pool_backprop<T>(args[0]->get_data_ptr<T>(), reference::avg_pool_backprop<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
...@@ -350,10 +375,11 @@ private: ...@@ -350,10 +375,11 @@ private:
apb->get_padding_below(), apb->get_padding_below(),
apb->get_padding_above(), apb->get_padding_above(),
apb->get_include_padding_in_avg_computation()); apb->get_include_padding_in_avg_computation());
break;
} }
else if (node_op == "Broadcast") case OP_TYPEID::Broadcast:
{ {
op::Broadcast* broadcast = dynamic_cast<op::Broadcast*>(&node); const op::Broadcast* broadcast = static_cast<const op::Broadcast*>(&node);
Shape in_shape = args[0]->get_shape(); Shape in_shape = args[0]->get_shape();
Shape out_shape = out[0]->get_shape(); Shape out_shape = out[0]->get_shape();
AxisSet broadcast_axes = broadcast->get_broadcast_axes(); AxisSet broadcast_axes = broadcast->get_broadcast_axes();
...@@ -362,13 +388,15 @@ private: ...@@ -362,13 +388,15 @@ private:
in_shape, in_shape,
out_shape, out_shape,
broadcast_axes); broadcast_axes);
break;
} }
else if (node_op == "Ceiling") case OP_TYPEID::Ceiling:
{ {
reference::ceiling<T>( reference::ceiling<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Concat") case OP_TYPEID::Concat:
{ {
const op::Concat* concat = static_cast<const op::Concat*>(&node); const op::Concat* concat = static_cast<const op::Concat*>(&node);
std::vector<const T*> in_args; std::vector<const T*> in_args;
...@@ -383,14 +411,16 @@ private: ...@@ -383,14 +411,16 @@ private:
in_shapes, in_shapes,
out[0]->get_shape(), out[0]->get_shape(),
concat->get_concatenation_axis()); concat->get_concatenation_axis());
break;
} }
else if (node_op == "Constant") case OP_TYPEID::Constant:
{ {
const op::Constant* c = static_cast<const op::Constant*>(&node); const op::Constant* c = static_cast<const op::Constant*>(&node);
reference::constant<T>( reference::constant<T>(
c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Convert") case OP_TYPEID::Convert:
{ {
// const op::Convert* c = static_cast<const op::Convert*>(&node); // const op::Convert* c = static_cast<const op::Convert*>(&node);
element::Type type = node.get_element_type(); element::Type type = node.get_element_type();
...@@ -466,10 +496,11 @@ private: ...@@ -466,10 +496,11 @@ private:
ss << "unsupported element type " << type << " op Convert"; ss << "unsupported element type " << type << " op Convert";
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
break;
} }
else if (node_op == "Convolution") case OP_TYPEID::Convolution:
{ {
auto c = static_cast<const op::Convolution*>(&node); const op::Convolution* c = static_cast<const op::Convolution*>(&node);
reference::convolution<T>(args[0]->get_data_ptr<T>(), reference::convolution<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
...@@ -488,10 +519,12 @@ private: ...@@ -488,10 +519,12 @@ private:
0, 0,
1, 1,
false); false);
break;
} }
else if (node_op == "ConvolutionBackpropFilters") case OP_TYPEID::ConvolutionBackpropFilters:
{ {
auto c = static_cast<const op::ConvolutionBackpropFilters*>(&node); const op::ConvolutionBackpropFilters* c =
static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(args[0]->get_data_ptr<T>(), reference::convolution<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
...@@ -510,11 +543,13 @@ private: ...@@ -510,11 +543,13 @@ private:
1, 1,
0, 0,
false); false);
break;
} }
else if (node_op == "ConvolutionBackpropData") case OP_TYPEID::ConvolutionBackpropData:
{ {
// Note that args[1] and args[0] are switched here from the usual order. // Note that args[1] and args[0] are switched here from the usual order.
auto c = static_cast<const op::ConvolutionBackpropData*>(&node); const op::ConvolutionBackpropData* c =
static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(args[1]->get_data_ptr<T>(), reference::convolution<T>(args[1]->get_data_ptr<T>(),
args[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
...@@ -533,27 +568,31 @@ private: ...@@ -533,27 +568,31 @@ private:
0, 0,
1, 1,
true); true);
break;
} }
else if (node_op == "Cos") case OP_TYPEID::Cos:
{ {
reference::cos<T>( reference::cos<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Cosh") case OP_TYPEID::Cosh:
{ {
reference::cosh<T>( reference::cosh<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Divide") case OP_TYPEID::Divide:
{ {
reference::divide<T>(args[0]->get_data_ptr<T>(), reference::divide<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "Dot") case OP_TYPEID::Dot:
{ {
op::Dot* dot = dynamic_cast<op::Dot*>(&node); const op::Dot* dot = static_cast<const op::Dot*>(&node);
reference::dot(args[0]->get_data_ptr<T>(), reference::dot(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
...@@ -562,26 +601,29 @@ private: ...@@ -562,26 +601,29 @@ private:
args[1]->get_shape(), args[1]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
dot->get_reduction_axes_count()); dot->get_reduction_axes_count());
break;
} }
case OP_TYPEID::Equal:
else if (node_op == "Equal")
{ {
reference::equal<T>(args[0]->get_data_ptr<T>(), reference::equal<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "Exp") case OP_TYPEID::Exp:
{ {
reference::exp<T>( reference::exp<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Floor") case OP_TYPEID::Floor:
{ {
reference::floor<T>( reference::floor<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "FunctionCall") case OP_TYPEID::FunctionCall:
{ {
std::shared_ptr<Function> function = node.get_functions()[0]; std::shared_ptr<Function> function = node.get_functions()[0];
...@@ -598,41 +640,47 @@ private: ...@@ -598,41 +640,47 @@ private:
} }
call(function, outputs, inputs); call(function, outputs, inputs);
break;
} }
else if (node_op == "Greater") case OP_TYPEID::Greater:
{ {
reference::greater<T>(args[0]->get_data_ptr<T>(), reference::greater<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "GreaterEq") case OP_TYPEID::GreaterEq:
{ {
reference::greater_eq<T>(args[0]->get_data_ptr<T>(), reference::greater_eq<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "Less") case OP_TYPEID::Less:
{ {
reference::less<T>(args[0]->get_data_ptr<T>(), reference::less<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "LessEq") case OP_TYPEID::LessEq:
{ {
reference::less_eq<T>(args[0]->get_data_ptr<T>(), reference::less_eq<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "Log") case OP_TYPEID::Log:
{ {
reference::log<T>( reference::log<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "LRN") case OP_TYPEID::LRN:
{ {
const op::LRN* lrn = static_cast<const op::LRN*>(&node); const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(args[0]->get_data_ptr<T>(), reference::lrn<T>(args[0]->get_data_ptr<T>(),
...@@ -642,8 +690,9 @@ private: ...@@ -642,8 +690,9 @@ private:
lrn->get_beta(), lrn->get_beta(),
lrn->get_bias(), lrn->get_bias(),
lrn->get_nsize()); lrn->get_nsize());
break;
} }
else if (node_op == "Max") case OP_TYPEID::Max:
{ {
const op::Max* max = static_cast<const op::Max*>(&node); const op::Max* max = static_cast<const op::Max*>(&node);
reference::max<T>(args[0]->get_data_ptr<T>(), reference::max<T>(args[0]->get_data_ptr<T>(),
...@@ -651,17 +700,19 @@ private: ...@@ -651,17 +700,19 @@ private:
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
max->get_reduction_axes()); max->get_reduction_axes());
break;
} }
else if (node_op == "Maximum") case OP_TYPEID::Maximum:
{ {
reference::maximum<T>(args[0]->get_data_ptr<T>(), reference::maximum<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "MaxPool") case OP_TYPEID::MaxPool:
{ {
op::MaxPool* max_pool = dynamic_cast<op::MaxPool*>(&node); const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node);
reference::max_pool<T>(args[0]->get_data_ptr<T>(), reference::max_pool<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
...@@ -671,10 +722,12 @@ private: ...@@ -671,10 +722,12 @@ private:
max_pool->get_window_movement_strides(), max_pool->get_window_movement_strides(),
max_pool->get_padding_below(), max_pool->get_padding_below(),
max_pool->get_padding_above()); max_pool->get_padding_above());
break;
} }
else if (node_op == "MaxPoolBackprop") case OP_TYPEID::MaxPoolBackprop:
{ {
op::MaxPoolBackprop* max_pool_backprop = dynamic_cast<op::MaxPoolBackprop*>(&node); const op::MaxPoolBackprop* max_pool_backprop =
static_cast<const op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(args[0]->get_data_ptr<T>(), reference::max_pool_backprop<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
...@@ -685,8 +738,9 @@ private: ...@@ -685,8 +738,9 @@ private:
max_pool_backprop->get_window_movement_strides(), max_pool_backprop->get_window_movement_strides(),
max_pool_backprop->get_padding_below(), max_pool_backprop->get_padding_below(),
max_pool_backprop->get_padding_above()); max_pool_backprop->get_padding_above());
break;
} }
else if (node_op == "Min") case OP_TYPEID::Min:
{ {
const op::Min* min = static_cast<const op::Min*>(&node); const op::Min* min = static_cast<const op::Min*>(&node);
reference::min<T>(args[0]->get_data_ptr<T>(), reference::min<T>(args[0]->get_data_ptr<T>(),
...@@ -694,60 +748,66 @@ private: ...@@ -694,60 +748,66 @@ private:
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
min->get_reduction_axes()); min->get_reduction_axes());
break;
} }
else if (node_op == "Minimum") case OP_TYPEID::Minimum:
{ {
reference::minimum<T>(args[0]->get_data_ptr<T>(), reference::minimum<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "Multiply") case OP_TYPEID::Multiply:
{ {
reference::multiply<T>(args[0]->get_data_ptr<T>(), reference::multiply<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "Negative") case OP_TYPEID::Negative:
{ {
reference::negate<T>( reference::negate<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Not") case OP_TYPEID::Not:
{ {
reference::logical_not( reference::logical_not(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "NotEqual") case OP_TYPEID::NotEqual:
{ {
reference::not_equal<T>(args[0]->get_data_ptr<T>(), reference::not_equal<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "OneHot") case OP_TYPEID::OneHot:
{ {
auto oh = static_cast<const op::OneHot*>(&node); const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(args[0]->get_data_ptr<T>(), reference::one_hot<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
oh->get_one_hot_axis()); oh->get_one_hot_axis());
break;
} }
else if (node_op == "Or") case OP_TYPEID::Or:
{ {
reference::logical_or(args[0]->get_data_ptr<T>(), reference::logical_or(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "Parameter") case OP_TYPEID::Parameter: break;
{ case OP_TYPEID::Pad:
}
else if (node_op == "Pad")
{ {
op::Pad* pad = dynamic_cast<op::Pad*>(&node); const op::Pad* pad = static_cast<const op::Pad*>(&node);
reference::pad(args[0]->get_data_ptr<T>(), reference::pad(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
...@@ -757,15 +817,17 @@ private: ...@@ -757,15 +817,17 @@ private:
pad->get_padding_below(), pad->get_padding_below(),
pad->get_padding_above(), pad->get_padding_above(),
pad->get_padding_interior()); pad->get_padding_interior());
break;
} }
else if (node_op == "Power") case OP_TYPEID::Power:
{ {
reference::power<T>(args[0]->get_data_ptr<T>(), reference::power<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "Product") case OP_TYPEID::Product:
{ {
const op::Product* product = static_cast<const op::Product*>(&node); const op::Product* product = static_cast<const op::Product*>(&node);
reference::product<T>(args[0]->get_data_ptr<T>(), reference::product<T>(args[0]->get_data_ptr<T>(),
...@@ -773,10 +835,11 @@ private: ...@@ -773,10 +835,11 @@ private:
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
product->get_reduction_axes()); product->get_reduction_axes());
break;
} }
else if (node_op == "Reduce") case OP_TYPEID::Reduce:
{ {
op::Reduce* reduce = dynamic_cast<op::Reduce*>(&node); const op::Reduce* reduce = static_cast<const op::Reduce*>(&node);
std::shared_ptr<Function> reduction_function = reduce->get_functions()[0]; std::shared_ptr<Function> reduction_function = reduce->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T { std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
...@@ -799,10 +862,11 @@ private: ...@@ -799,10 +862,11 @@ private:
node.get_output_shape(0), node.get_output_shape(0),
reduce->get_reduction_axes(), reduce->get_reduction_axes(),
f); f);
break;
} }
else if (node_op == "ReduceWindow") case OP_TYPEID::ReduceWindow:
{ {
op::ReduceWindow* reduce_window = dynamic_cast<op::ReduceWindow*>(&node); const op::ReduceWindow* reduce_window = static_cast<const op::ReduceWindow*>(&node);
std::shared_ptr<Function> reduction_function = reduce_window->get_functions()[0]; std::shared_ptr<Function> reduction_function = reduce_window->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T { std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
...@@ -826,20 +890,23 @@ private: ...@@ -826,20 +890,23 @@ private:
f, f,
reduce_window->get_window_shape(), reduce_window->get_window_shape(),
reduce_window->get_window_movement_strides()); reduce_window->get_window_movement_strides());
break;
} }
else if (node_op == "Relu") case OP_TYPEID::Relu:
{ {
reference::relu<T>( reference::relu<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "ReluBackprop") case OP_TYPEID::ReluBackprop:
{ {
reference::relu_backprop<T>(args[0]->get_data_ptr<T>(), reference::relu_backprop<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "ReplaceSlice") case OP_TYPEID::ReplaceSlice:
{ {
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node); const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(args[0]->get_data_ptr<T>(), reference::replace_slice<T>(args[0]->get_data_ptr<T>(),
...@@ -850,35 +917,39 @@ private: ...@@ -850,35 +917,39 @@ private:
slice->get_upper_bounds(), slice->get_upper_bounds(),
slice->get_strides(), slice->get_strides(),
out[0]->get_shape()); out[0]->get_shape());
break;
} }
else if (node_op == "Reshape") case OP_TYPEID::Reshape:
{ {
op::Reshape* reshape = dynamic_cast<op::Reshape*>(&node); const op::Reshape* reshape = static_cast<const op::Reshape*>(&node);
reference::reshape(args[0]->get_data_ptr<T>(), reference::reshape(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
reshape->get_input_order(), reshape->get_input_order(),
out[0]->get_shape()); out[0]->get_shape());
break;
} }
else if (node_op == "Result") case OP_TYPEID::Result:
{ {
op::Result* res = dynamic_cast<op::Result*>(&node); const op::Result* res = static_cast<const op::Result*>(&node);
reference::result(args[0]->get_data_ptr<T>(), reference::result(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
shape_size(res->get_shape())); shape_size(res->get_shape()));
break;
} }
else if (node_op == "Reverse") case OP_TYPEID::Reverse:
{ {
op::Reverse* reverse = dynamic_cast<op::Reverse*>(&node); const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
reference::reverse(args[0]->get_data_ptr<T>(), reference::reverse(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
reverse->get_reversed_axes()); reverse->get_reversed_axes());
break;
} }
else if (node_op == "ReverseSequence") case OP_TYPEID::ReverseSequence:
{ {
op::ReverseSequence* reverse = dynamic_cast<op::ReverseSequence*>(&node); const op::ReverseSequence* reverse = static_cast<const op::ReverseSequence*>(&node);
if (args[1]->get_element_type() == element::i32) if (args[1]->get_element_type() == element::i32)
{ {
...@@ -893,19 +964,21 @@ private: ...@@ -893,19 +964,21 @@ private:
{ {
throw ngraph_error("only int32 indices are supported"); throw ngraph_error("only int32 indices are supported");
} }
break;
} }
else if (node_op == "Select") case OP_TYPEID::Select:
{ {
reference::select<T>(args[0]->get_data_ptr<char>(), reference::select<T>(args[0]->get_data_ptr<char>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
args[2]->get_data_ptr<T>(), args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "SelectAndScatter") case OP_TYPEID::SelectAndScatter:
{ {
ngraph::op::SelectAndScatter* select_and_scatter = const ngraph::op::SelectAndScatter* select_and_scatter =
dynamic_cast<ngraph::op::SelectAndScatter*>(&node); static_cast<const ngraph::op::SelectAndScatter*>(&node);
std::shared_ptr<ngraph::Function> selection_function = std::shared_ptr<ngraph::Function> selection_function =
select_and_scatter->get_functions()[0]; select_and_scatter->get_functions()[0];
...@@ -949,35 +1022,41 @@ private: ...@@ -949,35 +1022,41 @@ private:
f_scatter, f_scatter,
select_and_scatter->get_window_shape(), select_and_scatter->get_window_shape(),
select_and_scatter->get_window_movement_strides()); select_and_scatter->get_window_movement_strides());
break;
} }
else if (node_op == "Sigmoid") case OP_TYPEID::Sigmoid:
{ {
reference::sigmoid<T>( reference::sigmoid<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "SigmoidBackprop") case OP_TYPEID::SigmoidBackprop:
{ {
reference::sigmoid_backprop<T>(args[0]->get_data_ptr<T>(), reference::sigmoid_backprop<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "Sign") case OP_TYPEID::Sign:
{ {
reference::sign<T>( reference::sign<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Sin") case OP_TYPEID::Sin:
{ {
reference::sin<T>( reference::sin<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Sinh") case OP_TYPEID::Sinh:
{ {
reference::sinh<T>( reference::sinh<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Slice") case OP_TYPEID::Slice:
{ {
const op::Slice* slice = static_cast<const op::Slice*>(&node); const op::Slice* slice = static_cast<const op::Slice*>(&node);
reference::slice<T>(args[0]->get_data_ptr<T>(), reference::slice<T>(args[0]->get_data_ptr<T>(),
...@@ -987,28 +1066,37 @@ private: ...@@ -987,28 +1066,37 @@ private:
slice->get_upper_bounds(), slice->get_upper_bounds(),
slice->get_strides(), slice->get_strides(),
out[0]->get_shape()); out[0]->get_shape());
break;
} }
else if (node_op == "Softmax") case OP_TYPEID::Softmax:
{ {
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node); const op::Softmax* softmax = static_cast<const op::Softmax*>(&node);
reference::softmax<T>(args[0]->get_data_ptr<T>(), reference::softmax<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_shape(), out[0]->get_shape(),
softmax->get_axes()); softmax->get_axes());
break;
} }
else if (node_op == "Sqrt") case OP_TYPEID::Sqrt:
{ {
reference::sqrt<T>( reference::sqrt<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
}
case OP_TYPEID::StopGradient:
{
// TODO: Throw a real unsupported_op when available
throw std::runtime_error("Unsupported op 'StopGradient'");
} }
else if (node_op == "Subtract") case OP_TYPEID::Subtract:
{ {
reference::subtract<T>(args[0]->get_data_ptr<T>(), reference::subtract<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
break;
} }
else if (node_op == "Sum") case OP_TYPEID::Sum:
{ {
const op::Sum* sum = static_cast<const op::Sum*>(&node); const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(args[0]->get_data_ptr<T>(), reference::sum<T>(args[0]->get_data_ptr<T>(),
...@@ -1016,18 +1104,21 @@ private: ...@@ -1016,18 +1104,21 @@ private:
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
sum->get_reduction_axes()); sum->get_reduction_axes());
break;
} }
else if (node_op == "Tan") case OP_TYPEID::Tan:
{ {
reference::tan<T>( reference::tan<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "Tanh") case OP_TYPEID::Tanh:
{ {
reference::tanh<T>( reference::tanh<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count()); args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
break;
} }
else if (node_op == "TopK") case OP_TYPEID::TopK:
{ {
const op::TopK* topk = static_cast<const op::TopK*>(&node); const op::TopK* topk = static_cast<const op::TopK*>(&node);
if (out[0]->get_element_type() == element::i64) if (out[0]->get_element_type() == element::i64)
...@@ -1057,11 +1148,7 @@ private: ...@@ -1057,11 +1148,7 @@ private:
throw ngraph_error("Unexpected type"); throw ngraph_error("Unexpected type");
} }
} }
else #pragma GCC diagnostic pop
{
std::stringstream ss;
ss << "unsupported op " << node_op;
throw ngraph_error(ss.str());
} }
} }
}; };
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
using namespace ngraph;
using namespace std;
runtime::interpreter::NodeWrapper::NodeWrapper(const shared_ptr<const Node>& node)
: m_node{node}
{
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// {"Abs", runtime::interpreter::OP_TYPEID::Abs},
// {"Acos", runtime::interpreter::OP_TYPEID::Acos},
// ...
#define NGRAPH_OP(a) {#a, runtime::interpreter::OP_TYPEID::a},
static unordered_map<string, runtime::interpreter::OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
auto it = typeid_map.find(m_node->description());
if (it != typeid_map.end())
{
m_typeid = it->second;
}
else
{
// TODO: use unsupported_op when that is merged to master
throw runtime_error(m_node->description());
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node.hpp"
namespace ngraph
{
namespace runtime
{
namespace interpreter
{
enum class OP_TYPEID;
class NodeWrapper;
}
}
}
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs,
// Acos,
// ...
#define NGRAPH_OP(a) a,
enum class ngraph::runtime::interpreter::OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
/// \brief This class allows adding an enum typeid to each Node. This makes dealing with
/// collections of Nodes a little easier and faster as we can use switch() instead of
/// if/else statements
class ngraph::runtime::interpreter::NodeWrapper
{
public:
NodeWrapper(const std::shared_ptr<const ngraph::Node>& node);
const Node& get_node() const { return *m_node; }
ngraph::runtime::interpreter::OP_TYPEID get_typeid() const { return m_typeid; }
private:
std::shared_ptr<const ngraph::Node> m_node;
OP_TYPEID m_typeid;
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment