Unverified Commit 956f66ad authored by Robert Kimball, committed by GitHub

Initial backend for non-IA CPUs (#2268)

* first cut at raspberry pi backend

* rename rpi to generic cpu

* disable cursed test
parent 20bd8bbc
CMakeLists.txt:

@@ -105,13 +105,14 @@ option(NGRAPH_GPU_ENABLE "Control the building of the GPU backend" FALSE)
option(NGRAPH_INTERPRETER_ENABLE "Control the building of the INTERPRETER backend" TRUE)
option(NGRAPH_NOP_ENABLE "Control the building of the NOP backend" TRUE)
option(NGRAPH_GPUH_ENABLE "Control the building of the Hybrid GPU backend" FALSE)
option(NGRAPH_GENERIC_CPU_ENABLE "Control the building of the generic CPU backend" FALSE)
option(NGRAPH_DISTRIBUTED_ENABLE "Add distributed mode to the CPU backend" FALSE)
option(NGRAPH_DEBUG_ENABLE "Enable output for NGRAPH_DEBUG statements" FALSE)
option(NGRAPH_ONNX_IMPORT_ENABLE "Enable ONNX importer" FALSE)
option(NGRAPH_DEX_ONLY "Build CPU DEX without codegen" FALSE)
option(NGRAPH_CODE_COVERAGE_ENABLE "Enable code coverage data collection" FALSE)
option(NGRAPH_LIB_VERSIONING_ENABLE "Enable shared library versioning" FALSE)
option(NGRAPH_PYTHON_BUILD_ENABLE "Enable build nGraph python package wheel" FALSE)
if (NGRAPH_GPUH_ENABLE)
set(NGRAPH_GPU_ENABLE TRUE)
@@ -125,6 +126,7 @@ message(STATUS "NGRAPH_GPU_ENABLE: ${NGRAPH_GPU_ENABLE}")
message(STATUS "NGRAPH_INTERPRETER_ENABLE: ${NGRAPH_INTERPRETER_ENABLE}")
message(STATUS "NGRAPH_NOP_ENABLE: ${NGRAPH_NOP_ENABLE}")
message(STATUS "NGRAPH_GPUH_ENABLE: ${NGRAPH_GPUH_ENABLE}")
message(STATUS "NGRAPH_GENERIC_CPU_ENABLE: ${NGRAPH_GENERIC_CPU_ENABLE}")
message(STATUS "NGRAPH_DISTRIBUTED_ENABLE: ${NGRAPH_DISTRIBUTED_ENABLE}")
message(STATUS "NGRAPH_DEBUG_ENABLE: ${NGRAPH_DEBUG_ENABLE}")
message(STATUS "NGRAPH_ONNX_IMPORT_ENABLE: ${NGRAPH_ONNX_IMPORT_ENABLE}")
src/ngraph/log.cpp:

@@ -30,7 +30,7 @@ using namespace ngraph;
void ngraph::default_logger_handler_func(const string& s)
{
-    cout << s << endl;
+    cout << s + "\n";
}
LogHelper::LogHelper(LOG_TYPE type,
src/ngraph/runtime/CMakeLists.txt:

@@ -37,6 +37,10 @@ if (NGRAPH_GPUH_ENABLE)
add_subdirectory(gpuh)
endif()
if (NGRAPH_GENERIC_CPU_ENABLE)
add_subdirectory(generic_cpu)
endif()
if (NGRAPH_PLAIDML_ENABLE)
add_subdirectory(plaidml)
endif()
src/ngraph/runtime/generic_cpu/CMakeLists.txt (new file):

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
if (NGRAPH_GENERIC_CPU_ENABLE)
find_package(OpenMP)
if (OPENMP_FOUND)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()
add_library(gcpu_backend SHARED gcpu_backend.cpp node_wrapper.cpp)
if(NGRAPH_LIB_VERSIONING_ENABLE)
set_target_properties(gcpu_backend PROPERTIES
VERSION ${NGRAPH_VERSION}
SOVERSION ${NGRAPH_API_VERSION})
endif()
target_link_libraries(gcpu_backend PRIVATE ngraph libeigen hybrid_base interpreter_backend)
target_compile_options(gcpu_backend PUBLIC -fopenmp)
set_target_properties(gcpu_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR})
install(TARGETS gcpu_backend
LIBRARY DESTINATION "${NGRAPH_INSTALL_LIB}"
ARCHIVE DESTINATION "${NGRAPH_INSTALL_LIB}"
)
endif()
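# Build sketch (illustrative, not part of the commit): with the option above
# enabled, an out-of-tree build of just this backend might look like:
#
#   cmake .. -DNGRAPH_GENERIC_CPU_ENABLE=TRUE
#   make gcpu_backend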
src/ngraph/runtime/generic_cpu/gcpu_backend.cpp (new file):

//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/generic_cpu/gcpu_backend.hpp"
#include <omp.h>
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/any_all_replacement.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/get_output_element_elimination.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/reshape_sinking.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/interpreter/int_backend.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
using descriptor::layout::DenseTensorLayout;
extern "C" const char* get_ngraph_version_string()
{
return NGRAPH_VERSION;
}
extern "C" runtime::Backend* new_backend(const char* configuration_string)
{
return new runtime::gcpu::GCPUBackend();
}
runtime::gcpu::GCPUBackend::GCPUBackend()
{
}
runtime::gcpu::GCPUBackend::GCPUBackend(const vector<string>& unsupported_op_name_list)
: m_unsupported_op_name_list{unsupported_op_name_list.begin(), unsupported_op_name_list.end()}
{
}
shared_ptr<runtime::Tensor> runtime::gcpu::GCPUBackend::create_tensor(const element::Type& type,
const Shape& shape)
{
return make_shared<runtime::HostTensor>(type, shape, this);
}
shared_ptr<runtime::Tensor> runtime::gcpu::GCPUBackend::create_tensor(const element::Type& type,
const Shape& shape,
void* memory_pointer)
{
return make_shared<runtime::HostTensor>(type, shape, memory_pointer, this);
}
runtime::Handle runtime::gcpu::GCPUBackend::compile(shared_ptr<Function> function)
{
FunctionInstance& instance = m_function_map[function];
if (!instance.m_is_compiled)
{
instance.m_is_compiled = true;
pass::Manager pass_manager;
// pass_manager.register_pass<pass::AnyAllReplacement>();
pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::NopElimination>();
pass_manager.register_pass<pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::AlgebraicSimplification>();
// pass_manager.register_pass<pass::CoreFusion>();
// pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
// pass_manager.register_pass<pass::GetOutputElementElimination>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>(get_alignment());
pass_manager.run_passes(function);
size_t memory_pool_size = function->get_temporary_pool_size();
instance.m_temporary_memory.reset(new AlignedBuffer(memory_pool_size, get_alignment()));
for (const shared_ptr<Node>& node : function->get_ordered_ops())
{
instance.m_wrapped_nodes.emplace_back(node);
}
}
return function;
}
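// Usage sketch (illustrative comment, not part of the commit): compile() must
// run before call(), and the Handle returned above is simply the Function
// itself. Assuming an f32 Function "f" with one 2x2 input and one 2x2 output:
//
//   runtime::gcpu::GCPUBackend backend;
//   auto a = backend.create_tensor(element::f32, Shape{2, 2});
//   auto r = backend.create_tensor(element::f32, Shape{2, 2});
//   auto handle = backend.compile(f);
//   backend.call(handle, {r}, {a});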
bool runtime::gcpu::GCPUBackend::call(shared_ptr<Function> function,
const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
auto fit = m_function_map.find(function);
if (fit == m_function_map.end())
{
throw runtime_error("compile() must be called before call().");
}
FunctionInstance& instance = fit->second;
if (!instance.m_is_compiled)
{
throw runtime_error("compile() must be called before call().");
}
// convert inputs to HostTensor
vector<void*> func_inputs;
vector<shared_ptr<runtime::HostTensor>> htv_inputs;
for (auto tensor : inputs)
{
auto host_tensor = static_pointer_cast<runtime::HostTensor>(tensor);
func_inputs.push_back(static_cast<void*>(host_tensor->get_data_ptr()));
htv_inputs.push_back(host_tensor);
}
// convert outputs to HostTensor
vector<void*> func_outputs;
for (auto tensor : outputs)
{
auto host_tensor = static_pointer_cast<runtime::HostTensor>(tensor);
func_outputs.push_back(static_cast<void*>(host_tensor->get_data_ptr()));
}
// map function params -> HostTensor
unordered_map<descriptor::Tensor*, void*> tensor_map;
size_t input_count = 0;
for (auto param : function->get_parameters())
{
for (size_t i = 0; i < param->get_output_size(); ++i)
{
descriptor::Tensor* tensor = param->get_output_tensor_ptr(i).get();
tensor_map.insert({tensor, func_inputs[input_count++]});
}
}
// map function outputs -> HostTensor
for (size_t output_count = 0; output_count < function->get_output_size(); ++output_count)
{
auto output = function->get_output_op(output_count);
if (!dynamic_pointer_cast<op::Result>(output))
{
throw ngraph_error("One of function's outputs isn't op::Result");
}
descriptor::Tensor* tensor = output->get_output_tensor_ptr(0).get();
tensor_map.insert({tensor, func_outputs[output_count]});
}
// for each ordered op in the graph
for (const NodeWrapper& wrapped : instance.m_wrapped_nodes)
{
const Node* op = &wrapped.get_node();
auto type_id = wrapped.get_typeid();
if (type_id == OP_TYPEID::Parameter)
{
continue;
}
if (type_id == OP_TYPEID::Constant)
{
const op::Constant* c = static_cast<const op::Constant*>(op);
descriptor::Tensor* tensor = op->get_output_tensor_ptr(0).get();
tensor_map.insert({tensor, const_cast<void*>(c->get_data_ptr())});
continue;
}
// get op inputs from map
vector<const void*> op_inputs;
for (const descriptor::Input& input : op->get_inputs())
{
descriptor::Tensor* tensor = input.get_output().get_tensor_ptr().get();
op_inputs.push_back(tensor_map.at(tensor));
}
// get op outputs from map or create
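// intermediate (non-parameter, non-result) tensors live in the preallocated
// temporary pool, at the offsets assigned by the MemoryLayout pass in compile()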
vector<void*> op_outputs;
vector<shared_ptr<runtime::HostTensor>> htv_outputs;
for (size_t i = 0; i < op->get_output_size(); ++i)
{
descriptor::Tensor* tensor = op->get_output_tensor_ptr(i).get();
void* host_tensor = nullptr;
auto it = tensor_map.find(tensor);
if (it == tensor_map.end())
{
auto offset = op->get_output_tensor(i).get_pool_offset();
host_tensor = instance.get_temporary_pointer(offset);
tensor_map.insert({tensor, host_tensor});
}
else
{
host_tensor = it->second;
}
op_outputs.push_back(host_tensor);
htv_outputs.push_back(make_shared<runtime::HostTensor>(
tensor->get_element_type(), tensor->get_shape(), host_tensor, this));
}
// get op type
element::Type type;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wswitch-enum"
switch (type_id)
{
case OP_TYPEID::Convert:
case OP_TYPEID::Quantize:
case OP_TYPEID::Dequantize:
case OP_TYPEID::ArgMin:
case OP_TYPEID::ArgMax: type = op->get_input_element_type(0); break;
case OP_TYPEID::Equal:
case OP_TYPEID::Greater:
case OP_TYPEID::GreaterEq:
case OP_TYPEID::Less:
case OP_TYPEID::LessEq:
case OP_TYPEID::NotEqual:
// Get the type of the second input, not the first
// All BinaryElementwiseComparison ops have the same type for inputs
// Select has bool for first input and the type we are interested in for the second
type = op->get_input_element_type(1);
break;
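// e.g. Equal(f32, f32) produces a boolean output tensor, so dispatching on
// the output type here would lose the f32 type the kernel actually compares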
case OP_TYPEID::TopK: type = op->get_output_element_type(1); break;
default: type = op->get_output_element_type(0); break;
}
#pragma GCC diagnostic pop
if (instance.m_performance_counters_enabled)
{
instance.m_timer_map[op].start();
}
generate_calls(type, wrapped, op_outputs, op_inputs, instance);
if (instance.m_performance_counters_enabled)
{
instance.m_timer_map[op].stop();
}
}
return true;
}
void runtime::gcpu::GCPUBackend::generate_calls(const element::Type& type,
const NodeWrapper& op,
const vector<void*>& outputs,
const vector<const void*>& inputs,
FunctionInstance& instance)
{
stringstream ss;
switch (type.get_type_enum())
{
case element::Type_t::boolean: op_engine<char>(op, outputs, inputs, instance); break;
case element::Type_t::f32: op_engine<float>(op, outputs, inputs, instance); break;
case element::Type_t::f64: op_engine<double>(op, outputs, inputs, instance); break;
case element::Type_t::i8: op_engine<int8_t>(op, outputs, inputs, instance); break;
case element::Type_t::i16: op_engine<int16_t>(op, outputs, inputs, instance); break;
case element::Type_t::i32: op_engine<int32_t>(op, outputs, inputs, instance); break;
case element::Type_t::i64: op_engine<int64_t>(op, outputs, inputs, instance); break;
case element::Type_t::u8: op_engine<uint8_t>(op, outputs, inputs, instance); break;
case element::Type_t::u16: op_engine<uint16_t>(op, outputs, inputs, instance); break;
case element::Type_t::u32: op_engine<uint32_t>(op, outputs, inputs, instance); break;
case element::Type_t::u64: op_engine<uint64_t>(op, outputs, inputs, instance); break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::bf16:
ss << "unsupported element type " << type << " op " << op.get_node().get_name();
throw ngraph_error(ss.str());
}
}
void runtime::gcpu::GCPUBackend::enable_performance_data(shared_ptr<Function> func, bool enable)
{
FunctionInstance& instance = m_function_map[func];
instance.m_performance_counters_enabled = enable;
}
vector<runtime::PerformanceCounter>
runtime::gcpu::GCPUBackend::get_performance_data(shared_ptr<Function> func) const
{
vector<runtime::PerformanceCounter> rc;
const FunctionInstance& instance = m_function_map.at(func);
for (const pair<const Node*, stopwatch>& p : instance.m_timer_map)
{
rc.emplace_back(p.first->get_name().c_str(),
p.second.get_total_microseconds(),
p.second.get_call_count());
}
return rc;
}
bool runtime::gcpu::GCPUBackend::is_supported(const Node& node) const
{
return m_unsupported_op_name_list.find(node.description()) == m_unsupported_op_name_list.end();
}
src/ngraph/runtime/generic_cpu/gcpu_backend.hpp (new file):

//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <initializer_list>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "ngraph/op/all.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/reduce.hpp"
#include "ngraph/op/reduce_window.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/generic_cpu/kernel/broadcast.hpp"
#include "ngraph/runtime/generic_cpu/kernel/dot.hpp"
#include "ngraph/runtime/generic_cpu/kernel/reshape.hpp"
#include "ngraph/runtime/generic_cpu/node_wrapper.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/all.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/any.hpp"
#include "ngraph/runtime/reference/argmax.hpp"
#include "ngraph/runtime/reference/argmin.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/constant.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/copy.hpp"
#include "ngraph/runtime/reference/cos.hpp"
#include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/embedding_lookup.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/generate_mask.hpp"
#include "ngraph/runtime/reference/greater.hpp"
#include "ngraph/runtime/reference/greater_eq.hpp"
#include "ngraph/runtime/reference/less.hpp"
#include "ngraph/runtime/reference/less_eq.hpp"
#include "ngraph/runtime/reference/log.hpp"
#include "ngraph/runtime/reference/lrn.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/max_pool.hpp"
#include "ngraph/runtime/reference/maximum.hpp"
#include "ngraph/runtime/reference/min.hpp"
#include "ngraph/runtime/reference/minimum.hpp"
#include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/negate.hpp"
#include "ngraph/runtime/reference/not.hpp"
#include "ngraph/runtime/reference/not_equal.hpp"
#include "ngraph/runtime/reference/one_hot.hpp"
#include "ngraph/runtime/reference/or.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/reduce.hpp"
#include "ngraph/runtime/reference/reduce_window.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/select_and_scatter.hpp"
#include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sin.hpp"
#include "ngraph/runtime/reference/sinh.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/state/rng_state.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/runtime/reference/allreduce.hpp"
#endif
namespace ngraph
{
namespace runtime
{
namespace gcpu
{
class GCPUBackend;
}
}
}
class ngraph::runtime::gcpu::GCPUBackend : public Backend
{
public:
GCPUBackend();
GCPUBackend(const std::vector<std::string>& unsupported_op_name_list);
GCPUBackend(const GCPUBackend&) = delete;
GCPUBackend(GCPUBackend&&) = delete;
GCPUBackend& operator=(const GCPUBackend&) = delete;
std::shared_ptr<Tensor>
create_tensor(const element::Type& type, const Shape& shape, void* memory_pointer) override;
std::shared_ptr<Tensor> create_tensor(const element::Type& type, const Shape& shape) override;
Handle compile(std::shared_ptr<Function> function) override;
bool call(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& inputs) override;
void enable_performance_data(std::shared_ptr<Function> func, bool enable) override;
std::vector<PerformanceCounter>
get_performance_data(std::shared_ptr<Function> func) const override;
bool is_supported(const Node& node) const override;
private:
int get_alignment() const { return 64; }
class FunctionInstance
{
public:
bool m_is_compiled = false;
bool m_performance_counters_enabled = false;
std::unordered_map<const Node*, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states;
std::shared_ptr<AlignedBuffer> m_temporary_memory;
void* get_temporary_pointer(size_t offset) { return m_temporary_memory->get_ptr(offset); }
};
std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
std::set<std::string> m_unsupported_op_name_list;
void generate_calls(const element::Type& type,
const NodeWrapper& op,
const std::vector<void*>& outputs,
const std::vector<const void*>& inputs,
FunctionInstance& instance);
template <typename T>
void op_engine(const NodeWrapper& node_wrapper,
const std::vector<void*>& out,
const std::vector<const void*>& args,
FunctionInstance& instance)
{
const Node& node = node_wrapper.get_node();
std::string node_op = node.description();
// We want to check that every OP_TYPEID enumeration is included in the list.
// These GCC flags enable compile-time checking so that if an enumeration
// is not in the list an error is generated.
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wcovered-switch-default"
switch (node_wrapper.get_typeid())
{
case OP_TYPEID::Abs:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::abs<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Acos:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::acos<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Add:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::add<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::All:
{
const op::All* all = static_cast<const op::All*>(&node);
reference::all(static_cast<const char*>(args[0]),
static_cast<char*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
all->get_reduction_axes());
break;
}
case OP_TYPEID::AllReduce: {
#ifdef NGRAPH_DISTRIBUTED
reference::allreduce<T>(static_cast<T*>(const_cast<void*>(args[0])),
static_cast<T*>(out[0]),
node.get_input_element_type(0),
static_cast<int>(shape_size(node.get_input_shape(0))));
#endif
break;
}
case OP_TYPEID::And:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_and(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Any:
{
const op::Any* any = static_cast<const op::Any*>(&node);
reference::any(static_cast<const char*>(args[0]),
static_cast<char*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
any->get_reduction_axes());
break;
}
case OP_TYPEID::ArgMin:
{
const op::ArgMin* argmin = static_cast<const op::ArgMin*>(&node);
auto element_type = node.get_output_element_type(0);
if (element_type == element::i64)
{
reference::argmin<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
argmin->get_reduction_axis());
}
else if (element_type == element::i32)
{
reference::argmin<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
argmin->get_reduction_axis());
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::ArgMax:
{
const op::ArgMax* argmax = static_cast<const op::ArgMax*>(&node);
auto element_type = node.get_output_element_type(0);
if (element_type == element::i64)
{
reference::argmax<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
argmax->get_reduction_axis());
}
else if (element_type == element::i32)
{
reference::argmax<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
argmax->get_reduction_axis());
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::Asin:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::asin<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Atan:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::atan<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::AvgPool:
{
const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
reference::avg_pool<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
avg_pool->get_window_shape(),
avg_pool->get_window_movement_strides(),
avg_pool->get_padding_below(),
avg_pool->get_padding_above(),
avg_pool->get_include_padding_in_avg_computation());
break;
}
case OP_TYPEID::GenerateMask:
{
if (instance.m_states.count(&node) == 0)
{
const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
instance.m_states[&node] = std::unique_ptr<ngraph::RNGState>(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability()));
}
bool training = static_cast<bool>(static_cast<const T*>(args[0])[0]);
auto state = instance.m_states.at(&node).get();
size_t element_count = shape_size(node.get_output_shape(0));
reference::generate_mask<T>(
reinterpret_cast<T*>(out[0]), element_count, state, training);
break;
}
case OP_TYPEID::GetOutputElement:
{
const op::GetOutputElement* get_output_element =
static_cast<const op::GetOutputElement*>(&node);
size_t n = get_output_element->get_n();
size_t element_count = shape_size(node.get_output_shape(0));
size_t num_bytes = element_count * node.get_output_element_type(0).size();
std::memcpy(static_cast<T*>(out[0]), args[n], num_bytes);
break;
}
case OP_TYPEID::BatchNormTraining:
{
const ngraph::op::BatchNormTraining* bn =
static_cast<const ngraph::op::BatchNormTraining*>(&node);
reference::batch_norm_training<T>(bn->get_eps_value(),
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<T*>(out[0]),
static_cast<T*>(out[1]),
static_cast<T*>(out[2]),
node.get_input_shape(2));
break;
}
case OP_TYPEID::BatchNormInference:
{
const ngraph::op::BatchNormInference* bn =
static_cast<const ngraph::op::BatchNormInference*>(&node);
reference::batch_norm_inference<T>(bn->get_eps_value(),
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<const T*>(args[3]),
static_cast<const T*>(args[4]),
static_cast<T*>(out[0]),
node.get_input_shape(2));
break;
}
case OP_TYPEID::BatchNormTrainingBackprop:
{
const ngraph::op::BatchNormTrainingBackprop* bn_bprop =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(&node);
reference::batch_norm_backprop(bn_bprop->get_eps_value(),
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<const T*>(args[3]),
static_cast<const T*>(args[4]),
static_cast<const T*>(args[5]),
static_cast<T*>(out[0]),
static_cast<T*>(out[1]),
static_cast<T*>(out[2]),
node.get_input_shape(2));
break;
}
case OP_TYPEID::AvgPoolBackprop:
{
const op::AvgPoolBackprop* apb = static_cast<const op::AvgPoolBackprop*>(&node);
reference::avg_pool_backprop<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
apb->get_window_shape(),
apb->get_window_movement_strides(),
apb->get_padding_below(),
apb->get_padding_above(),
apb->get_include_padding_in_avg_computation());
break;
}
case OP_TYPEID::Broadcast:
{
const op::Broadcast* broadcast = static_cast<const op::Broadcast*>(&node);
gcpu::kernel::broadcast(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
broadcast->get_broadcast_axes());
break;
}
case OP_TYPEID::BroadcastLike: break;
case OP_TYPEID::Ceiling:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::ceiling<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Concat:
{
const op::Concat* concat = static_cast<const op::Concat*>(&node);
std::vector<const T*> in_args;
std::vector<Shape> in_shapes;
for (size_t i = 0; i < node.get_input_size(); i++)
{
in_args.push_back(static_cast<const T*>(args[i]));
in_shapes.push_back(node.get_input_shape(i));
}
reference::concat<T>(in_args,
static_cast<T*>(out[0]),
in_shapes,
node.get_output_shape(0),
concat->get_concatenation_axis());
break;
}
case OP_TYPEID::Constant:
{
// Constant is handled in the main loop
break;
}
case OP_TYPEID::ScalarConstantLike: break;
case OP_TYPEID::Convert:
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
element::Type type = node.get_element_type();
std::stringstream ss;
size_t element_count = shape_size(node.get_output_shape(0));
switch (type.get_type_enum())
{
case element::Type_t::boolean:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<char*>(out[0]), element_count);
break;
case element::Type_t::f32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<float*>(out[0]), element_count);
break;
case element::Type_t::f64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<double*>(out[0]), element_count);
break;
case element::Type_t::i8:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int8_t*>(out[0]), element_count);
break;
case element::Type_t::i16:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int16_t*>(out[0]), element_count);
break;
case element::Type_t::i32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int32_t*>(out[0]), element_count);
break;
case element::Type_t::i64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int64_t*>(out[0]), element_count);
break;
case element::Type_t::u8:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint8_t*>(out[0]), element_count);
break;
case element::Type_t::u16:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint16_t*>(out[0]), element_count);
break;
case element::Type_t::u32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint32_t*>(out[0]), element_count);
break;
case element::Type_t::u64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint64_t*>(out[0]), element_count);
break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::bf16:
ss << "unsupported element type " << type << " op Convert";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::Convolution:
{
const op::Convolution* c = static_cast<const op::Convolution*>(&node);
reference::convolution<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
c->get_window_movement_strides(),
c->get_window_dilation_strides(),
c->get_padding_below(),
c->get_padding_above(),
c->get_data_dilation_strides(),
0,
1,
1,
0,
0,
1,
false);
break;
}
case OP_TYPEID::ConvolutionBackpropFilters:
{
const op::ConvolutionBackpropFilters* c =
static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
1,
0,
0,
1,
1,
0,
false);
break;
}
case OP_TYPEID::ConvolutionBackpropData:
{
// Note that args[1] and args[0] are switched here from the usual order.
const op::ConvolutionBackpropData* c =
static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(static_cast<const T*>(args[1]),
static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(1),
node.get_input_shape(0),
node.get_output_shape(0),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
0,
1,
0,
1,
0,
1,
true);
break;
}
case OP_TYPEID::Cos:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::cos<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Cosh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::cosh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Dequantize:
{
const op::Dequantize* dequantize = static_cast<const op::Dequantize*>(&node);
auto type = dequantize->get_element_type();
if (type == element::f32)
{
reference::dequantize<T>(static_cast<const T*>(args[0]),
static_cast<const float*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<float*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
dequantize->get_axes());
}
else if (type == element::f64)
{
reference::dequantize<T>(static_cast<const T*>(args[0]),
static_cast<const double*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<double*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
dequantize->get_axes());
}
else
{
std::stringstream ss;
ss << "unsupported element type " << type << " op Dequantize";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::Divide:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::divide<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Dot:
{
const op::Dot* dot = static_cast<const op::Dot*>(&node);
gcpu::kernel::dot(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
dot->get_reduction_axes_count());
break;
}
case OP_TYPEID::EmbeddingLookup:
{
const op::EmbeddingLookup* embed = static_cast<const op::EmbeddingLookup*>(&node);
auto type = embed->get_argument(0)->get_element_type();
size_t element_count = shape_size(embed->get_argument(0)->get_shape());
if (type == element::f32)
{
reference::embedding<T, float>(static_cast<const float*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count,
embed->get_shape());
}
else if (type == element::f64)
{
reference::embedding<T, double>(static_cast<const double*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count,
embed->get_shape());
}
else if (type == element::i32)
{
reference::embedding<T, int>(static_cast<const int*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count,
embed->get_shape());
}
else if (type == element::i64)
{
reference::embedding<T, int64_t>(static_cast<const int64_t*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count,
embed->get_shape());
}
else
{
throw ngraph_error(std::string("Unsupported index type ") + type.c_type_string() +
std::string(" in EmbeddingLookup"));
}
break;
}
case OP_TYPEID::Equal:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::equal<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Exp:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::exp<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Floor:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::floor<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::FunctionCall:
{
std::shared_ptr<Function> function = node.get_functions()[0];
std::vector<std::shared_ptr<runtime::Tensor>> outputs;
for (size_t i = 0; i < function->get_output_size(); i++)
{
element::Type et = function->get_output_element_type(i);
Shape shape = function->get_output_shape(i);
auto host_tensor = std::make_shared<HostTensor>(et, shape, out[i], this);
outputs.push_back(std::static_pointer_cast<runtime::Tensor>(host_tensor));
}
std::vector<std::shared_ptr<runtime::Tensor>> inputs;
auto parameters = function->get_parameters();
for (size_t i = 0; i < parameters.size(); i++)
{
auto parameter = parameters[i];
element::Type et = parameter->get_element_type();
Shape shape = parameter->get_shape();
auto host_tensor =
std::make_shared<HostTensor>(et, shape, const_cast<void*>(args[i]), this);
inputs.push_back(std::static_pointer_cast<runtime::Tensor>(host_tensor));
}
auto handle = compile(function);
call(handle, outputs, inputs);
break;
}
case OP_TYPEID::Greater:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::greater<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::GreaterEq:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::greater_eq<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Less:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::less<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::LessEq:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::less_eq<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Log:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::log<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::LRN:
{
const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
lrn->get_alpha(),
lrn->get_beta(),
lrn->get_bias(),
lrn->get_nsize());
break;
}
case OP_TYPEID::Max:
{
const op::Max* max = static_cast<const op::Max*>(&node);
reference::max<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
max->get_reduction_axes());
break;
}
case OP_TYPEID::Maximum:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::maximum<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::MaxPool:
{
const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node);
reference::max_pool<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
max_pool->get_window_shape(),
max_pool->get_window_movement_strides(),
max_pool->get_padding_below(),
max_pool->get_padding_above());
break;
}
case OP_TYPEID::MaxPoolBackprop:
{
const op::MaxPoolBackprop* max_pool_backprop =
static_cast<const op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(1),
node.get_output_shape(0),
max_pool_backprop->get_window_shape(),
max_pool_backprop->get_window_movement_strides(),
max_pool_backprop->get_padding_below(),
max_pool_backprop->get_padding_above());
break;
}
case OP_TYPEID::Min:
{
const op::Min* min = static_cast<const op::Min*>(&node);
reference::min<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
min->get_reduction_axes());
break;
}
case OP_TYPEID::Minimum:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::minimum<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Multiply:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::multiply<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Negative:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::negate<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Not:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_not(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::NotEqual:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::not_equal<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::OneHot:
{
const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
oh->get_one_hot_axis());
break;
}
case OP_TYPEID::Or:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_or(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Parameter: break;
case OP_TYPEID::Pad:
{
const op::Pad* pad = static_cast<const op::Pad*>(&node);
reference::pad(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_inputs().at(0).get_shape(),
node.get_output_shape(0),
pad->get_padding_below(),
pad->get_padding_above(),
pad->get_padding_interior());
break;
}
case OP_TYPEID::Power:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::power<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Product:
{
const op::Product* product = static_cast<const op::Product*>(&node);
reference::product<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
product->get_reduction_axes());
break;
}
case OP_TYPEID::Quantize:
{
const op::Quantize* quantize = static_cast<const op::Quantize*>(&node);
auto type = quantize->get_element_type();
if (type == element::u8)
{
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const uint8_t*>(args[2]),
static_cast<uint8_t*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
quantize->get_round_mode());
}
else if (type == element::i8)
{
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const int8_t*>(args[2]),
static_cast<int8_t*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
quantize->get_round_mode());
}
else if (type == element::i32)
{
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const int32_t*>(args[2]),
static_cast<int32_t*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
quantize->get_round_mode());
}
else
{
std::stringstream ss;
ss << "unsupported element type " << type << " op Quantize";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::Reduce:
{
const op::Reduce* reduce = static_cast<const op::Reduce*>(&node);
std::shared_ptr<Function> reduction_function = reduce->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
auto tx = std::static_pointer_cast<HostTensor>(
create_tensor(node.get_inputs().at(0).get_element_type(), Shape{}, &x));
auto ty = std::static_pointer_cast<HostTensor>(
create_tensor(node.get_inputs().at(1).get_element_type(), Shape{}, &y));
auto tr = std::static_pointer_cast<HostTensor>(
create_tensor(node.get_output_element_type(0), Shape{}));
auto handle = compile(reduction_function);
call(handle, {tr}, {tx, ty});
return *(tr->get_data_ptr<T>());
};
reference::reduce(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_inputs().at(0).get_shape(),
node.get_output_shape(0),
reduce->get_reduction_axes(),
f);
break;
}
case OP_TYPEID::ReduceWindow:
{
const op::ReduceWindow* reduce_window = static_cast<const op::ReduceWindow*>(&node);
std::shared_ptr<Function> reduction_function = reduce_window->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
auto tx = std::static_pointer_cast<HostTensor>(
create_tensor(node.get_inputs().at(0).get_element_type(), Shape{}, &x));
auto ty = std::static_pointer_cast<HostTensor>(
create_tensor(node.get_inputs().at(1).get_element_type(), Shape{}, &y));
auto tr = std::static_pointer_cast<HostTensor>(
create_tensor(node.get_output_element_type(0), Shape{}));
auto handle = compile(reduction_function);
call(handle, {tr}, {tx, ty});
return *(tr->get_data_ptr<T>());
};
reference::reduce_window(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_inputs().at(0).get_shape(),
node.get_output_shape(0),
f,
reduce_window->get_window_shape(),
reduce_window->get_window_movement_strides());
break;
}
case OP_TYPEID::Relu:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::relu<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::ReluBackprop:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::relu_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::ReplaceSlice:
{
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(1),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
slice->get_strides(),
node.get_output_shape(0));
break;
}
case OP_TYPEID::Reshape:
{
const op::Reshape* reshape = static_cast<const op::Reshape*>(&node);
gcpu::kernel::reshape<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
reshape->get_input_order(),
node.get_output_shape(0));
break;
}
case OP_TYPEID::Result:
{
const op::Result* res = static_cast<const op::Result*>(&node);
reference::result(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
shape_size(res->get_shape()));
break;
}
case OP_TYPEID::Reverse:
{
const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
reference::reverse(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
reverse->get_reversed_axes());
break;
}
case OP_TYPEID::ReverseSequence:
{
const op::ReverseSequence* reverse = static_cast<const op::ReverseSequence*>(&node);
if (node.get_input_element_type(1) == element::i32)
{
reference::reverse_sequence<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
reverse->get_batch_axis(),
reverse->get_sequence_axis(),
static_cast<const int32_t*>(args[1]));
}
else
{
throw ngraph_error("only int32 indices are supported");
}
break;
}
case OP_TYPEID::Select:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::select<T>(static_cast<const char*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::SelectAndScatter:
{
const ngraph::op::SelectAndScatter* select_and_scatter =
static_cast<const ngraph::op::SelectAndScatter*>(&node);
std::shared_ptr<ngraph::Function> selection_function =
select_and_scatter->get_functions()[0];
std::function<bool(T, T)> f_selection = [this, &node, selection_function](T x,
T y) -> bool {
auto tx = std::static_pointer_cast<HostTensor>(
create_tensor(node.get_inputs().at(0).get_element_type(), Shape{}, &x));
auto ty = std::static_pointer_cast<HostTensor>(
create_tensor(node.get_inputs().at(1).get_element_type(), Shape{}, &y));
auto tr =
std::static_pointer_cast<HostTensor>(create_tensor(element::boolean, Shape{}));
auto handle = compile(selection_function);
call(handle, {tr}, {tx, ty});
return *(tr->get_data_ptr<char>());
};
std::shared_ptr<ngraph::Function> scatter_function =
select_and_scatter->get_functions()[1];
std::function<T(T, T)> f_scatter = [this, &node, scatter_function](T x, T y) -> T {
auto tx = std::static_pointer_cast<HostTensor>(
create_tensor(node.get_inputs().at(0).get_element_type(), Shape{}, &x));
auto ty = std::static_pointer_cast<HostTensor>(
create_tensor(node.get_inputs().at(1).get_element_type(), Shape{}, &y));
auto tr = std::static_pointer_cast<HostTensor>(
create_tensor(node.get_output_element_type(0), Shape{}));
auto handle = compile(scatter_function);
call(handle, {tr}, {tx, ty});
return *(tr->get_data_ptr<T>());
};
reference::select_and_scatter<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
f_selection,
f_scatter,
select_and_scatter->get_window_shape(),
select_and_scatter->get_window_movement_strides());
break;
}
case OP_TYPEID::ShapeOf:
{
reference::shape_of(node.get_input_shape(0), static_cast<uint64_t*>(out[0]));
break;
}
case OP_TYPEID::Sigmoid:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::SigmoidBackprop:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Sign:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sign<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Sin:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sin<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Sinh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sinh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Slice:
{
const op::Slice* slice = static_cast<const op::Slice*>(&node);
reference::slice<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
slice->get_strides(),
node.get_output_shape(0));
break;
}
case OP_TYPEID::Softmax:
{
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node);
reference::softmax<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_output_shape(0),
softmax->get_axes());
break;
}
case OP_TYPEID::Sqrt:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sqrt<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::StopGradient:
{
throw unsupported_op("Unsupported op 'StopGradient'");
}
case OP_TYPEID::Subtract:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::subtract<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Sum:
{
const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
sum->get_reduction_axes());
break;
}
case OP_TYPEID::Tan:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::tan<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Tanh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::tanh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::TopK:
{
const op::TopK* topk = static_cast<const op::TopK*>(&node);
if (node.get_output_element_type(0) == element::i64)
{
reference::topk<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
static_cast<T*>(out[1]),
node.get_input_shape(0),
node.get_output_shape(0),
topk->get_top_k_axis(),
topk->get_k(),
topk->get_compute_max());
}
else if (node.get_output_element_type(0) == element::i32)
{
reference::topk<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
static_cast<T*>(out[1]),
node.get_input_shape(0),
node.get_output_shape(0),
topk->get_top_k_axis(),
topk->get_k(),
topk->get_compute_max());
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
default: throw unsupported_op("Unsupported op '" + node.description() + "'");
#pragma GCC diagnostic pop
}
}
};
src/ngraph/runtime/generic_cpu/kernel/broadcast.hpp (new file):

//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <Eigen/Dense>
#include <cmath>
#include <omp.h>
#include <utility>
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/shape_util.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace runtime
{
namespace gcpu
{
namespace kernel
{
inline std::tuple<size_t, size_t> get_start_finish(size_t size)
{
const size_t nthreads = omp_get_num_threads();
const size_t ithread = omp_get_thread_num();
const size_t start = ithread * size / nthreads;
const size_t finish = (ithread + 1) * size / nthreads;
return std::make_tuple(start, finish);
}
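// Worked example (for illustration): inside "#pragma omp parallel" with
// size == 10 and 4 threads, the per-thread chunks come out as
// [0,2), [2,5), [5,7), [7,10).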
template <typename T>
void broadcast_2d(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
size_t index[2];
size_t* out_index =
(broadcast_axes.find(0) == broadcast_axes.end() ? &index[0] : &index[1]);
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
out[index[0] * out_shape[1] + index[1]] = in[*out_index];
}
}
}
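// Worked example (for illustration): in_shape = {3}, out_shape = {2, 3},
// broadcast_axes = {0}. Axis 0 is broadcast, so out_index aliases index[1]
// and every output row is a copy of the input: out[i][j] = in[j].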
// #define PARALLEL
template <typename T>
void broadcast_3d(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
#ifdef PARALLEL
#pragma omp parallel
#endif
{
size_t start;
size_t finish;
#ifdef PARALLEL
std::tie(start, finish) = get_start_finish(out_shape[0]);
#else
start = 0;
finish = out_shape[0];
#endif
size_t index[3];
size_t* out_index = nullptr;
for (size_t i = 0; i < 3; i++)
{
if (broadcast_axes.count(i) == 0)
{
out_index = &index[i];
break;
}
}
for (index[0] = start; index[0] < finish; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
out[index[0] * out_shape[1] * out_shape[2] +
index[1] * out_shape[2] + index[2]] = in[*out_index];
}
}
}
}
}
template <typename T>
void broadcast_4d(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
size_t index[4];
size_t* out_index = nullptr;
for (size_t i = 0; i < 4; i++)
{
if (broadcast_axes.count(i) == 0)
{
out_index = &index[i];
break;
}
}
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
out[index[0] * out_shape[1] * out_shape[2] * out_shape[3] +
index[1] * out_shape[2] * out_shape[3] +
index[2] * out_shape[3] + index[3]] = in[*out_index];
}
}
}
}
}
template <typename T>
void broadcast(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
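                    // Dispatch: scalar inputs fill the output directly, 1-D inputs use
                    // the unrolled 2-4-D kernels above, and everything else falls back
                    // to the reference kernel.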
if (in_shape.size() == 0)
{
for (size_t i = 0; i < shape_size(out_shape); ++i)
{
out[i] = in[0];
}
}
else if (in_shape.size() == 1)
{
switch (out_shape.size())
{
case 2:
broadcast_2d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
case 3:
broadcast_3d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
case 4:
broadcast_4d<T>(in, out, in_shape, out_shape, broadcast_axes);
                            break;
                        default:
                            // No unrolled kernel for this output rank; fall back to the
                            // reference implementation rather than silently writing nothing.
                            runtime::reference::broadcast<T>(
                                in, out, in_shape, out_shape, broadcast_axes);
                            break;
                        }
}
else
{
runtime::reference::broadcast<T>(
in, out, in_shape, out_shape, broadcast_axes);
}
}
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <Eigen/Dense>
#include <algorithm>
#include <cmath>
#include <omp.h>
#include <utility>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
namespace runtime
{
namespace gcpu
{
namespace kernel
{
template <typename T>
void dot(const T* arg0,
const T* arg1,
T* out,
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& out_shape,
size_t reduction_axes_count)
{
if (arg0_shape.size() == 2 && arg1_shape.size() == 2 && out_shape.size() == 2)
{
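                        // Fast path: a plain matrix-matrix product. Map the raw buffers
                        // as row-major Eigen matrices and let Eigen do the multiply.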
                        Eigen::Map<
                            const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
                            a0(arg0, arg0_shape[0], arg0_shape[1]);
                        Eigen::Map<
                            const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
                            a1(arg1, arg1_shape[0], arg1_shape[1]);
                        Eigen::Map<
                            Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
                            o(out, out_shape[0], out_shape[1]);
o = a0 * a1;
}
else
{
// Get the sizes of the dot axes. It's easiest to pull them from arg1 because they're
// right up front.
Shape dot_axis_sizes(reduction_axes_count);
std::copy(arg1_shape.begin(),
arg1_shape.begin() + reduction_axes_count,
dot_axis_sizes.begin());
CoordinateTransform arg0_transform(arg0_shape);
CoordinateTransform arg1_transform(arg1_shape);
CoordinateTransform output_transform(out_shape);
// Create coordinate transforms for arg0 and arg1 that throw away the dotted axes.
size_t arg0_projected_rank = arg0_shape.size() - reduction_axes_count;
size_t arg1_projected_rank = arg1_shape.size() - reduction_axes_count;
Shape arg0_projected_shape(arg0_projected_rank);
std::copy(arg0_shape.begin(),
arg0_shape.begin() + arg0_projected_rank,
arg0_projected_shape.begin());
Shape arg1_projected_shape(arg1_projected_rank);
std::copy(arg1_shape.begin() + reduction_axes_count,
arg1_shape.end(),
arg1_projected_shape.begin());
CoordinateTransform arg0_projected_transform(arg0_projected_shape);
CoordinateTransform arg1_projected_transform(arg1_projected_shape);
// Create a coordinate transform that allows us to iterate over all possible values
// for the dotted axes.
CoordinateTransform dot_axes_transform(dot_axis_sizes);
for (const Coordinate& arg0_projected_coord : arg0_projected_transform)
{
for (const Coordinate& arg1_projected_coord : arg1_projected_transform)
{
// The output coordinate is just the concatenation of the projected coordinates.
Coordinate out_coord(arg0_projected_coord.size() +
arg1_projected_coord.size());
auto out_coord_it = std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
out_coord.begin());
std::copy(arg1_projected_coord.begin(),
arg1_projected_coord.end(),
out_coord_it);
// Zero out to start the sum.
T sum = 0;
size_t out_index = output_transform.index(out_coord);
// Walk along the dotted axes.
Coordinate arg0_coord(arg0_shape.size());
Coordinate arg1_coord(arg1_shape.size());
auto arg0_it = std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
arg0_coord.begin());
for (const Coordinate& dot_axis_positions : dot_axes_transform)
{
// In order to find the points to multiply together, we need to inject our current
// positions along the dotted axes back into the projected arg0 and arg1 coordinates.
std::copy(dot_axis_positions.begin(),
dot_axis_positions.end(),
arg0_it);
auto arg1_it = std::copy(dot_axis_positions.begin(),
dot_axis_positions.end(),
arg1_coord.begin());
std::copy(arg1_projected_coord.begin(),
arg1_projected_coord.end(),
arg1_it);
// Multiply and add to the sum.
sum += arg0[arg0_transform.index(arg0_coord)] *
arg1[arg1_transform.index(arg1_coord)];
}
// Write the sum back.
out[out_index] = sum;
}
}
}
}
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/axis_vector.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace gcpu
{
namespace kernel
{
template <typename T>
void reshape_in0(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
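                    // Rank-0 tensor: a single element, so axis order is irrelevant.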
*out = *in;
}
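                // reshape_in1..reshape_in6 share one scheme: size[i] is the extent of
                // the i-th input axis visited (input axes taken in in_axis_order), and
                // map_index[axis] points at the loop counter that walks input axis
                // `axis`, so reads follow the permuted order while writes stream
                // sequentially through the output.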
template <typename T>
void reshape_in1(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[1];
size_t in_index[1];
size_t* map_index[1];
for (size_t i = 0; i < 1; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
*out++ = in[*map_index[0]];
}
}
template <typename T>
void reshape_in2(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[2];
size_t in_index[2];
size_t* map_index[2];
for (size_t i = 0; i < 2; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
*out++ = in[*map_index[0] * in_shape[1] + *map_index[1]];
}
}
}
template <typename T>
void reshape_in3(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[3];
size_t in_index[3];
size_t* map_index[3];
for (size_t i = 0; i < 3; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
*out++ = in[*map_index[0] * in_shape[1] * in_shape[2] +
*map_index[1] * in_shape[2] + *map_index[2]];
}
}
}
}
template <typename T>
void reshape_in4(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[4];
size_t in_index[4];
size_t* map_index[4];
for (size_t i = 0; i < 4; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
{
*out++ =
in[*map_index[0] * in_shape[1] * in_shape[2] * in_shape[3] +
*map_index[1] * in_shape[2] * in_shape[3] +
*map_index[2] * in_shape[3] + *map_index[3]];
}
}
}
}
}
template <typename T>
void reshape_in5(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[5];
size_t in_index[5];
size_t* map_index[5];
for (size_t i = 0; i < 5; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
{
for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4])
{
*out++ = in[*map_index[0] * in_shape[1] * in_shape[2] *
in_shape[3] * in_shape[4] +
*map_index[1] * in_shape[2] * in_shape[3] *
in_shape[4] +
*map_index[2] * in_shape[3] * in_shape[4] +
*map_index[3] * in_shape[4] + *map_index[4]];
}
}
}
}
}
}
template <typename T>
void reshape_in6(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
size_t size[6];
size_t in_index[6];
size_t* map_index[6];
for (size_t i = 0; i < 6; i++)
{
size[i] = in_shape[in_axis_order[i]];
map_index[in_axis_order[i]] = &in_index[i];
}
for (in_index[0] = 0; in_index[0] < size[0]; ++in_index[0])
{
for (in_index[1] = 0; in_index[1] < size[1]; ++in_index[1])
{
for (in_index[2] = 0; in_index[2] < size[2]; ++in_index[2])
{
for (in_index[3] = 0; in_index[3] < size[3]; ++in_index[3])
{
for (in_index[4] = 0; in_index[4] < size[4]; ++in_index[4])
{
for (in_index[5] = 0; in_index[5] < size[5]; ++in_index[5])
{
*out++ =
in[*map_index[0] * in_shape[1] * in_shape[2] *
in_shape[3] * in_shape[4] * in_shape[5] +
*map_index[1] * in_shape[2] * in_shape[3] *
in_shape[4] * in_shape[5] +
*map_index[2] * in_shape[3] * in_shape[4] *
in_shape[5] +
*map_index[3] * in_shape[4] * in_shape[5] +
*map_index[4] * in_shape[5] + *map_index[5]];
}
}
}
}
}
}
}
template <typename T>
void reshape(const T* in,
T* out,
const Shape& in_shape,
const AxisVector& in_axis_order,
const Shape& out_shape)
{
switch (in_shape.size())
{
case 0: reshape_in0<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 1: reshape_in1<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 2: reshape_in2<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 3: reshape_in3<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 4: reshape_in4<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 5: reshape_in5<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 6: reshape_in6<T>(in, out, in_shape, in_axis_order, out_shape); break;
default:
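                        // No unrolled kernel beyond rank 6; log and use the reference version.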
NGRAPH_INFO << "reference::reshape";
reference::reshape(in, out, in_shape, in_axis_order, out_shape);
break;
}
}
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <cmath>
#include <cstring>
#include <numeric>
#include <vector>
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace gcpu
{
namespace kernel
{
template <typename T>
void result(const T* arg, T* out, size_t count)
{
memcpy(out, arg, sizeof(T) * count);
}
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/generic_cpu/node_wrapper.hpp"
using namespace ngraph;
using namespace std;
runtime::gcpu::NodeWrapper::NodeWrapper(const shared_ptr<const Node>& node)
: m_node{node}
{
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// {"Abs", runtime::gcpu::OP_TYPEID::Abs},
// {"Acos", runtime::gcpu::OP_TYPEID::Acos},
// ...
#define NGRAPH_OP(a, b) {#a, runtime::gcpu::OP_TYPEID::a},
static unordered_map<string, runtime::gcpu::OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
auto it = typeid_map.find(m_node->description());
if (it != typeid_map.end())
{
m_typeid = it->second;
}
else
{
throw unsupported_op("Unsupported op '" + m_node->description() + "'");
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node.hpp"
namespace ngraph
{
namespace runtime
{
namespace gcpu
{
enum class OP_TYPEID;
class NodeWrapper;
}
}
}
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs,
// Acos,
// ...
#define NGRAPH_OP(a, b) a,
enum class ngraph::runtime::gcpu::OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
/// \brief This class allows adding an enum typeid to each Node. This makes dealing with
/// collections of Nodes a little easier and faster as we can use switch() instead of
/// if/else statements.
class ngraph::runtime::gcpu::NodeWrapper
{
public:
NodeWrapper(const std::shared_ptr<const ngraph::Node>& node);
const Node& get_node() const { return *m_node; }
ngraph::runtime::gcpu::OP_TYPEID get_typeid() const { return m_typeid; }
private:
std::shared_ptr<const ngraph::Node> m_node;
OP_TYPEID m_typeid;
};
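//
// Typical use (a sketch; `node` is any shared_ptr<const Node>): wrap the node
// once, then dispatch on the enum instead of comparing description strings:
//
//     NodeWrapper wrapper{node};
//     switch (wrapper.get_typeid())
//     {
//     case OP_TYPEID::Abs: /* run the Abs kernel */ break;
//     default: throw unsupported_op("Unsupported op '" + node->description() + "'");
//     }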
......@@ -20,6 +20,7 @@ batch_norm_inference_f64
batch_norm_training_0eps_f64
batch_norm_one_output
batch_norm_three_outputs
batch_norm_bprop_n4c3h2w2
dequantize
dequantize_axes
dequantize_int32
......
......@@ -45,6 +45,9 @@ endif()
if (NGRAPH_PLAIDML_ENABLE)
target_link_libraries(nbench plaidml_backend)
endif()
if (NGRAPH_GENERIC_CPU_ENABLE)
target_link_libraries(nbench gcpu_backend)
endif()
if (NGRAPH_DISTRIBUTED_ENABLE)
target_compile_definitions(nbench PRIVATE NGRAPH_DISTRIBUTED)
......
......@@ -15,7 +15,9 @@
//*****************************************************************************
#include <random>
#ifdef __i386__
#include <xmmintrin.h>
#endif
#include "benchmark.hpp"
#include "ngraph/file_util.hpp"
......@@ -107,53 +109,20 @@ void init_real_tv(shared_ptr<runtime::Tensor> tv, T min, T max)
static void random_init(shared_ptr<runtime::Tensor> tv)
{
element::Type et = tv->get_element_type();
if (et == element::boolean)
{
init_int_tv<char>(tv, 0, 1);
}
else if (et == element::f32)
{
init_real_tv<float>(tv, -1, 1);
}
else if (et == element::f64)
{
init_real_tv<double>(tv, -1, 1);
}
else if (et == element::i8)
{
init_int_tv<int8_t>(tv, -1, 1);
}
else if (et == element::i16)
{
init_int_tv<int16_t>(tv, -1, 1);
}
else if (et == element::i32)
{
init_int_tv<int32_t>(tv, 0, 1);
}
else if (et == element::i64)
{
init_int_tv<int64_t>(tv, -1, 1);
}
else if (et == element::u8)
{
init_int_tv<uint8_t>(tv, 0, 1);
}
else if (et == element::u16)
{
init_int_tv<uint16_t>(tv, 0, 1);
}
else if (et == element::u32)
{
init_int_tv<uint32_t>(tv, 0, 1);
}
else if (et == element::u64)
{
init_int_tv<uint64_t>(tv, 0, 1);
}
else
{
throw runtime_error("unsupported type");
switch (et.get_type_enum())
{
case element::Type_t::boolean: init_int_tv<char>(tv, 0, 1); break;
case element::Type_t::f32: init_real_tv<float>(tv, -1, 1); break;
case element::Type_t::f64: init_real_tv<double>(tv, -1, 1); break;
case element::Type_t::i8: init_int_tv<int8_t>(tv, -1, 1); break;
case element::Type_t::i16: init_int_tv<int16_t>(tv, -1, 1); break;
case element::Type_t::i32: init_int_tv<int32_t>(tv, 0, 1); break;
case element::Type_t::i64: init_int_tv<int64_t>(tv, -1, 1); break;
case element::Type_t::u8: init_int_tv<uint8_t>(tv, 0, 1); break;
case element::Type_t::u16: init_int_tv<uint16_t>(tv, 0, 1); break;
case element::Type_t::u32: init_int_tv<uint32_t>(tv, 0, 1); break;
case element::Type_t::u64: init_int_tv<uint64_t>(tv, 0, 1); break;
default: throw runtime_error("unsupported type");
}
}
......
......@@ -106,6 +106,10 @@ if (NGRAPH_PLAIDML_ENABLE)
set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} PlaidML)
endif()
if (NGRAPH_GENERIC_CPU_ENABLE)
set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} GCPU)
endif()
add_subdirectory(models)
add_subdirectory(files)
add_subdirectory(util)
......