Unverified commit 334a55fa, authored by Scott Cyphers, committed by GitHub

Merge branch 'master' into rearhart/plaidml

parents 5cfe1075 47342339
@@ -18,10 +18,10 @@ include(ExternalProject)
 # Includes blas 3.8.0 in mkldnn
 set(NGRAPH_MKLDNN_SHORT_VERSION 0)
-set(NGRAPH_MKLDNN_FULL_VERSION 0.19.0.0)
-set(NGRAPH_MKLDNN_VERSION "v0.19")
+set(NGRAPH_MKLDNN_FULL_VERSION 0.20.0.0)
+set(NGRAPH_MKLDNN_VERSION "v0.20")
 set(NGRAPH_MKLDNN_SUB_VERSION "2019.0.5.20190502")
-set(NGRAPH_MKLDNN_GIT_TAG "027de76")
+set(NGRAPH_MKLDNN_GIT_TAG "v0.20")
 #------------------------------------------------------------------------------
 # Fetch and install MKL-DNN
...
@@ -28,16 +28,3 @@ index f10feb20..05f47961 100644
 set_property(TARGET ${LIB_NAME} PROPERTY PUBLIC_HEADER ${HEADERS})
 target_include_directories(${LIB_NAME} PUBLIC
-diff --git a/src/cpu/jit_avx512_common_conv_kernel.cpp b/src/cpu/jit_avx512_common_conv_kernel.cpp
-index 1bb98fa43..b8b54401f 100644
---- a/src/cpu/jit_avx512_common_conv_kernel.cpp
-+++ b/src/cpu/jit_avx512_common_conv_kernel.cpp
-@@ -3055,7 +3055,7 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel_3d() {
- void jit_avx512_common_conv_bwd_weights_kernel_f32
-     ::compute_oh_loop_common()
- {
--    assert(jcp.harness == harness_mb_reduction);
-+    assert(one_of(jcp.harness, harness_mb_reduction, harness_3d_reduction));
-     int b_pad = jcp.b_pad;
-     int t_pad = jcp.t_pad;
-     bool is_dilated = jcp.dilate_h != 0;
 pytest
 tox
+pydocstyle==3.0.0
 flake8
 flake8-commas
 flake8-comprehensions
...
@@ -49,7 +49,8 @@ public:
     }
 };
-std::unique_ptr<ngraph::runtime::Allocator> ngraph::runtime::create_default_allocator()
+ngraph::runtime::Allocator* ngraph::runtime::get_default_allocator()
 {
-    return std::unique_ptr<DefaultAllocator>(new DefaultAllocator());
+    static std::unique_ptr<DefaultAllocator> allocator(new DefaultAllocator());
+    return allocator.get();
 }
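Review note: get_default_allocator() now hands out a non-owning pointer to a function-local singleton, so every caller shares one allocator instead of each owning a fresh one. A minimal sketch of the pattern with illustrative class names (not the nGraph source):

    #include <memory>

    class Allocator
    {
    public:
        virtual ~Allocator() = default;
    };

    class DefaultAllocator : public Allocator
    {
    };

    // The function-local static is constructed once, on first call, and lives
    // until program exit; callers share the instance and must not delete it.
    Allocator* get_default_allocator()
    {
        static std::unique_ptr<DefaultAllocator> allocator(new DefaultAllocator());
        return allocator.get();
    }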
@@ -30,7 +30,7 @@ namespace ngraph
         class DefaultAllocator;
         /// \brief Create a default allocator that calls into system
         ///        allocation libraries
-        std::unique_ptr<Allocator> create_default_allocator();
+        ngraph::runtime::Allocator* get_default_allocator();
     }
 }
...
@@ -185,7 +185,7 @@ runtime::Allocator* runtime::cpu::CPU_Backend::get_host_memory_allocator()
 {
     if (!m_allocator)
     {
-        m_allocator = create_default_allocator();
+        return runtime::get_default_allocator();
     }
     return m_allocator.get();
 }
...
@@ -15,10 +15,10 @@
 # ******************************************************************************
 if (NGRAPH_GENERIC_CPU_ENABLE)
-    find_package(OpenMP)
-    if (OPENMP_FOUND)
-        add_compile_options(${OpenMP_CXX_FLAGS})
-    endif()
+    # find_package(OpenMP)
+    # if (OPENMP_FOUND)
+    #     add_compile_options(${OpenMP_CXX_FLAGS})
+    # endif()
     add_library(gcpu_backend SHARED gcpu_backend.cpp gcpu_executable.cpp node_wrapper.cpp)
     if(NGRAPH_LIB_VERSIONING_ENABLE)
         set_target_properties(gcpu_backend PROPERTIES
...
@@ -52,14 +52,14 @@ runtime::gcpu::GCPUBackend::GCPUBackend(const vector<string>& unsupported_op_name_list)
 shared_ptr<runtime::Tensor> runtime::gcpu::GCPUBackend::create_tensor(const element::Type& type,
                                                                       const Shape& shape)
 {
-    return make_shared<runtime::HostTensor>(type, shape, this);
+    return make_shared<runtime::HostTensor>(type, shape);
 }
 shared_ptr<runtime::Tensor> runtime::gcpu::GCPUBackend::create_tensor(const element::Type& type,
                                                                       const Shape& shape,
                                                                       void* memory_pointer)
 {
-    return make_shared<runtime::HostTensor>(type, shape, memory_pointer, this);
+    return make_shared<runtime::HostTensor>(type, shape, memory_pointer);
 }
 shared_ptr<runtime::Executable>
...
@@ -15,17 +15,22 @@
 //*****************************************************************************
 #include "ngraph/runtime/generic_cpu/gcpu_executable.hpp"
+#include "ngraph/cpio.hpp"
 #include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
 #include "ngraph/except.hpp"
 #include "ngraph/op/convert.hpp"
 #include "ngraph/op/select.hpp"
 #include "ngraph/op/util/binary_elementwise_comparison.hpp"
 #include "ngraph/pass/assign_layout.hpp"
+#include "ngraph/pass/core_fusion.hpp"
+#include "ngraph/pass/fused_op_decomposition.hpp"
+#include "ngraph/pass/implicit_broadcast_elimination.hpp"
 #include "ngraph/pass/like_replacement.hpp"
 #include "ngraph/pass/liveness.hpp"
 #include "ngraph/pass/manager.hpp"
 #include "ngraph/pass/memory_layout.hpp"
 #include "ngraph/runtime/backend_manager.hpp"
+#include "ngraph/serializer.hpp"
 #include "ngraph/util.hpp"
 using namespace std;
@@ -35,21 +40,35 @@ using descriptor::layout::DenseTensorLayout;
 runtime::gcpu::GCPUExecutable::GCPUExecutable(const shared_ptr<Function>& function,
                                               bool enable_performance_collection)
+    : m_is_compiled{true}
+    , m_performance_counters_enabled{enable_performance_collection}
 {
-    m_is_compiled = true;
+    m_function = clone_function(*function);
     pass::Manager pass_manager;
     pass_manager.register_pass<pass::LikeReplacement>();
+    pass_manager.register_pass<pass::FusedOpDecomposition>();
+    pass_manager.register_pass<pass::ImplicitBroadcastElimination>();
     pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
     pass_manager.register_pass<pass::Liveness>();
-    pass_manager.run_passes(function);
-    for (const shared_ptr<Node>& node : function->get_ordered_ops())
+    pass_manager.run_passes(m_function);
+    for (const shared_ptr<Node>& node : m_function->get_ordered_ops())
     {
         m_wrapped_nodes.emplace_back(node);
     }
-    set_parameters_and_results(*function);
+    set_parameters_and_results(*m_function);
+}
+
+runtime::gcpu::GCPUExecutable::GCPUExecutable(const std::string& model_string)
+    : m_is_compiled{true}
+    , m_performance_counters_enabled{false}
+{
+    m_function = deserialize(model_string);
+    for (const shared_ptr<Node>& node : m_function->get_ordered_ops())
+    {
+        m_wrapped_nodes.emplace_back(node);
+    }
+    set_parameters_and_results(*m_function);
 }
 bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
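Review note: the new string-based constructor is the inverse of the save() added further down; the Function travels through nGraph's JSON serializer in both directions. A minimal round-trip sketch, assuming the serialize/deserialize free functions declared in ngraph/serializer.hpp:

    #include "ngraph/function.hpp"
    #include "ngraph/serializer.hpp"

    // Serialize a Function to its JSON string form and rebuild it, mirroring
    // what the string-based constructor above consumes.
    std::shared_ptr<ngraph::Function> round_trip(const std::shared_ptr<ngraph::Function>& f)
    {
        std::string model_string = ngraph::serialize(f, 0);
        return ngraph::deserialize(model_string);
    }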
@@ -82,7 +101,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
     {
         for (size_t i = 0; i < param->get_output_size(); ++i)
         {
-            descriptor::Tensor* tensor = param->get_output_tensor_ptr(i).get();
+            descriptor::Tensor* tensor = &param->output(i).get_tensor();
             tensor_map.insert({tensor, func_inputs[input_count++]});
         }
     }
@@ -95,14 +114,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
         {
             throw ngraph_error("One of function's outputs isn't op::Result");
         }
-        descriptor::Tensor* tensor = output->get_output_tensor_ptr(0).get();
+        descriptor::Tensor* tensor = &output->output(0).get_tensor();
         tensor_map.insert({tensor, func_outputs[output_count]});
     }
     // for each ordered op in the graph
     for (const NodeWrapper& wrapped : m_wrapped_nodes)
     {
-        const Node* op = &wrapped.get_node();
+        auto op = wrapped.get_node();
         auto type_id = wrapped.get_typeid();
         if (type_id == OP_TYPEID::Parameter)
         {
@@ -111,9 +130,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
         // get op inputs from map
         vector<shared_ptr<HostTensor>> op_inputs;
-        for (const descriptor::Input& input : op->get_inputs())
+        for (auto input : op->inputs())
         {
-            descriptor::Tensor* tensor = input.get_output().get_tensor_ptr().get();
+            descriptor::Tensor* tensor = &input.get_tensor();
             op_inputs.push_back(tensor_map.at(tensor));
         }
@@ -121,14 +140,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
         vector<shared_ptr<HostTensor>> op_outputs;
         for (size_t i = 0; i < op->get_output_size(); ++i)
         {
-            descriptor::Tensor* tensor = op->get_output_tensor_ptr(i).get();
+            descriptor::Tensor* tensor = &op->output(i).get_tensor();
             shared_ptr<HostTensor> host_tensor;
             auto it = tensor_map.find(tensor);
             if (it == tensor_map.end())
             {
                 const Shape& shape = op->get_output_shape(i);
                 const element::Type& type = op->get_output_element_type(i);
-                string name = op->get_output_tensor(i).get_name();
+                string name = op->output(i).get_tensor().get_name();
                 host_tensor = make_shared<runtime::HostTensor>(type, shape, name);
                 tensor_map.insert({tensor, host_tensor});
             }
@@ -177,7 +196,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
         }
         if (m_nan_check_enabled)
        {
-            perform_nan_check(op_outputs, op);
+            perform_nan_check(op_outputs, op.get());
         }
     }
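Review note: all of the hunks above thread through one map from graph tensors to live buffers. A compressed sketch of that find-or-allocate pattern, with simplified stand-in types (not the verbatim source):

    #include <memory>
    #include <unordered_map>

    // Simplified stand-ins for descriptor::Tensor and runtime::HostTensor.
    struct GraphTensor {};
    struct Buffer {};

    using TensorMap = std::unordered_map<GraphTensor*, std::shared_ptr<Buffer>>;

    // Look up the buffer bound to a graph tensor, allocating one on first use;
    // this mirrors the find/insert dance around intermediate op outputs above.
    std::shared_ptr<Buffer> get_or_create(TensorMap& map, GraphTensor* key)
    {
        auto it = map.find(key);
        if (it == map.end())
        {
            auto buffer = std::make_shared<Buffer>();
            map.insert({key, buffer});
            return buffer;
        }
        return it->second;
    }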
@@ -186,19 +205,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
 void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
                                                    const NodeWrapper& op,
-                                                   const vector<shared_ptr<HostTensor>>& outputs,
-                                                   const vector<shared_ptr<HostTensor>>& inputs)
+                                                   const vector<shared_ptr<HostTensor>>& out,
+                                                   const vector<shared_ptr<HostTensor>>& in)
 {
-    vector<void*> out;
-    vector<const void*> in;
-    for (auto t : outputs)
-    {
-        out.push_back(t->get_data_ptr());
-    }
-    for (auto t : inputs)
-    {
-        in.push_back(t->get_data_ptr());
-    }
     stringstream ss;
     switch (type.get_type_enum())
     {
@@ -216,7 +225,8 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
     case element::Type_t::undefined:
     case element::Type_t::dynamic:
     case element::Type_t::bf16:
+    case element::Type_t::f16:
-        ss << "unsupported element type " << type << " op " << op.get_node().get_name();
+        ss << "unsupported element type " << type << " op " << op.get_node()->get_name();
         throw ngraph_error(ss.str());
     }
 }
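Review note: with the raw-pointer marshalling deleted, generate_calls is reduced to a single dispatch on element type that instantiates the typed op_engine<T>. A minimal sketch of that dispatch shape, with illustrative names (not the nGraph source):

    #include <cstdint>
    #include <stdexcept>

    enum class ElementType { f32, f64, i32 }; // illustrative subset

    template <typename T>
    void op_engine_sketch() { /* per-type op execution would go here */ }

    // Dispatch once on the element type, then run the typed engine;
    // unsupported types fall through to an error, like f16/bf16 above.
    void generate_calls_sketch(ElementType type)
    {
        switch (type)
        {
        case ElementType::f32: op_engine_sketch<float>(); break;
        case ElementType::f64: op_engine_sketch<double>(); break;
        case ElementType::i32: op_engine_sketch<int32_t>(); break;
        default: throw std::runtime_error("unsupported element type");
        }
    }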
@@ -229,11 +239,9 @@ void runtime::gcpu::GCPUExecutable::set_nan_check(bool enable)
 vector<runtime::PerformanceCounter> runtime::gcpu::GCPUExecutable::get_performance_data() const
 {
     vector<runtime::PerformanceCounter> rc;
-    for (const pair<const Node*, stopwatch> p : m_timer_map)
+    for (const pair<shared_ptr<const Node>, stopwatch> p : m_timer_map)
     {
-        rc.emplace_back(p.first->get_name().c_str(),
-                        p.second.get_total_microseconds(),
-                        p.second.get_call_count());
+        rc.emplace_back(p.first, p.second.get_total_microseconds(), p.second.get_call_count());
     }
     return rc;
 }
@@ -286,3 +294,12 @@ void runtime::gcpu::GCPUExecutable::perform_nan_check(const vector<shared_ptr<HostTensor>>& tensors,
         arg_number++;
     }
 }
+
+void runtime::gcpu::GCPUExecutable::save(ostream& out)
+{
+    cpio::Writer writer(out);
+    string si = "INTERPRETER Save File 1.0";
+    writer.write("save_info", si.data(), si.size());
+    string model = serialize(m_function, 0);
+    writer.write("model", model.data(), model.size());
+}
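Review note: a hypothetical caller-side sketch for the save() added above; the header path and file name are assumptions, not part of this commit:

    #include <fstream>
    #include "ngraph/runtime/executable.hpp"

    // Persist a compiled executable: save() streams a small "save_info"
    // marker plus the serialized Function as cpio archive members.
    void save_to_disk(ngraph::runtime::Executable& exec)
    {
        std::ofstream file("model.cpio", std::ios::binary);
        exec.save(file);
    }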
@@ -17,24 +17,31 @@
 #pragma once
 #include <initializer_list>
+#include <iostream>
 #include <memory>
 #include <sstream>
 #include <string>
 #include <vector>
 #include "ngraph/op/all.hpp"
+#include "ngraph/op/allreduce.hpp"
 #include "ngraph/op/any.hpp"
 #include "ngraph/op/argmax.hpp"
 #include "ngraph/op/argmin.hpp"
 #include "ngraph/op/avg_pool.hpp"
 #include "ngraph/op/batch_norm.hpp"
 #include "ngraph/op/broadcast.hpp"
+#include "ngraph/op/broadcast_distributed.hpp"
 #include "ngraph/op/concat.hpp"
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/convolution.hpp"
 #include "ngraph/op/dequantize.hpp"
+#include "ngraph/op/divide.hpp"
 #include "ngraph/op/dot.hpp"
 #include "ngraph/op/embedding_lookup.hpp"
+#include "ngraph/op/experimental/batch_mat_mul.hpp"
+#include "ngraph/op/experimental/dyn_broadcast.hpp"
+#include "ngraph/op/experimental/dyn_pad.hpp"
 #include "ngraph/op/experimental/generate_mask.hpp"
 #include "ngraph/op/experimental/shape_of.hpp"
 #include "ngraph/op/gather.hpp"
@@ -48,11 +55,14 @@
 #include "ngraph/op/passthrough.hpp"
 #include "ngraph/op/product.hpp"
 #include "ngraph/op/quantize.hpp"
+#include "ngraph/op/quantized_convolution.hpp"
+#include "ngraph/op/recv.hpp"
 #include "ngraph/op/replace_slice.hpp"
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/result.hpp"
 #include "ngraph/op/reverse.hpp"
 #include "ngraph/op/reverse_sequence.hpp"
+#include "ngraph/op/send.hpp"
 #include "ngraph/op/slice.hpp"
 #include "ngraph/op/softmax.hpp"
 #include "ngraph/op/sum.hpp"
@@ -64,7 +74,6 @@
 #include "ngraph/runtime/generic_cpu/kernel/reshape.hpp"
 #include "ngraph/runtime/generic_cpu/node_wrapper.hpp"
 #include "ngraph/runtime/host_tensor.hpp"
-#include "ngraph/runtime/interpreter/node_wrapper.hpp"
 #include "ngraph/runtime/reference/abs.hpp"
 #include "ngraph/runtime/reference/acos.hpp"
 #include "ngraph/runtime/reference/add.hpp"
@@ -77,7 +86,9 @@
 #include "ngraph/runtime/reference/asin.hpp"
 #include "ngraph/runtime/reference/atan.hpp"
 #include "ngraph/runtime/reference/avg_pool.hpp"
+#include "ngraph/runtime/reference/batch_mat_mul.hpp"
 #include "ngraph/runtime/reference/batch_norm.hpp"
+#include "ngraph/runtime/reference/broadcast.hpp"
 #include "ngraph/runtime/reference/broadcast_distributed.hpp"
 #include "ngraph/runtime/reference/ceiling.hpp"
 #include "ngraph/runtime/reference/concat.hpp"
@@ -89,8 +100,10 @@
 #include "ngraph/runtime/reference/cosh.hpp"
 #include "ngraph/runtime/reference/dequantize.hpp"
 #include "ngraph/runtime/reference/divide.hpp"
+#include "ngraph/runtime/reference/dot.hpp"
 #include "ngraph/runtime/reference/embedding_lookup.hpp"
 #include "ngraph/runtime/reference/equal.hpp"
+#include "ngraph/runtime/reference/erf.hpp"
 #include "ngraph/runtime/reference/exp.hpp"
 #include "ngraph/runtime/reference/floor.hpp"
 #include "ngraph/runtime/reference/gather.hpp"
@@ -117,14 +130,17 @@
 #include "ngraph/runtime/reference/power.hpp"
 #include "ngraph/runtime/reference/product.hpp"
 #include "ngraph/runtime/reference/quantize.hpp"
+#include "ngraph/runtime/reference/recv.hpp"
 #include "ngraph/runtime/reference/relu.hpp"
 #include "ngraph/runtime/reference/replace_slice.hpp"
+#include "ngraph/runtime/reference/reshape.hpp"
 #include "ngraph/runtime/reference/result.hpp"
 #include "ngraph/runtime/reference/reverse.hpp"
 #include "ngraph/runtime/reference/reverse_sequence.hpp"
 #include "ngraph/runtime/reference/scatter_add.hpp"
 #include "ngraph/runtime/reference/scatter_nd_add.hpp"
 #include "ngraph/runtime/reference/select.hpp"
+#include "ngraph/runtime/reference/send.hpp"
 #include "ngraph/runtime/reference/shape_of.hpp"
 #include "ngraph/runtime/reference/sigmoid.hpp"
 #include "ngraph/runtime/reference/sign.hpp"
@@ -134,6 +150,7 @@
 #include "ngraph/runtime/reference/softmax.hpp"
 #include "ngraph/runtime/reference/sqrt.hpp"
 #include "ngraph/runtime/reference/subtract.hpp"
+#include "ngraph/runtime/reference/sum.hpp"
 #include "ngraph/runtime/reference/tan.hpp"
 #include "ngraph/runtime/reference/tanh.hpp"
 #include "ngraph/runtime/reference/topk.hpp"
@@ -154,6 +171,8 @@ namespace ngraph
 class ngraph::runtime::gcpu::GCPUExecutable : public Executable
 {
+    friend class GCPUBackend;
+
 public:
     GCPUExecutable(const std::shared_ptr<Function>& function,
                    bool enable_performance_collection = false);
@@ -161,20 +180,25 @@ public:
     bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
               const std::vector<std::shared_ptr<Tensor>>& intputs) override;
+    virtual void save(std::ostream& output_stream) override;
     void set_nan_check(bool enable);
     std::vector<PerformanceCounter> get_performance_data() const override;
 private:
+    GCPUExecutable(const std::string& model_string);
+
+    int get_alignment() const { return 64; }
     bool m_is_compiled = false;
     bool m_nan_check_enabled = false;
     bool m_performance_counters_enabled = false;
+    std::shared_ptr<Function> m_function;
-    std::unordered_map<const Node*, stopwatch> m_timer_map;
+    std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
     std::vector<NodeWrapper> m_wrapped_nodes;
     std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states;
     std::set<std::string> m_unsupported_op_name_list;
-    int get_alignment() const { return 64; }
     static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
                                   const Node* op = nullptr);
@@ -185,11 +209,10 @@ private:
     template <typename T>
     void op_engine(const NodeWrapper& node_wrapper,
-                   const std::vector<void*>& out,
-                   const std::vector<const void*>& args)
+                   const std::vector<std::shared_ptr<HostTensor>>& out,
+                   const std::vector<std::shared_ptr<HostTensor>>& args)
     {
-        const Node& node = node_wrapper.get_node();
-        std::string node_op = node.description();
+        const Node& node = *node_wrapper.get_node();
         // We want to check that every OP_TYPEID enumeration is included in the list.
         // These GCC flags enable compile-time checking so that if an enumeration
@@ -206,30 +229,30 @@ private:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
            reference::abs<T>(
-                static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
+                args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
            break;
        }
        case OP_TYPEID::Acos:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
            reference::acos<T>(
-                static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
+                args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
            break;
        }
        case OP_TYPEID::Add:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
-            reference::add<T>(static_cast<const T*>(args[0]),
-                              static_cast<const T*>(args[1]),
-                              static_cast<T*>(out[0]),
+            reference::add<T>(args[0]->get_data_ptr<const T>(),
+                              args[1]->get_data_ptr<const T>(),
+                              out[0]->get_data_ptr<T>(),
                              element_count);
            break;
        }
        case OP_TYPEID::All:
        {
            const op::All* all = static_cast<const op::All*>(&node);
-            reference::all(static_cast<const char*>(args[0]),
-                           static_cast<char*>(out[0]),
+            reference::all(args[0]->get_data_ptr<const char>(),
+                           out[0]->get_data_ptr<char>(),
                           node.get_input_shape(0),
                           node.get_output_shape(0),
                           all->get_reduction_axes());
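Review note: the mechanical change repeated through the rest of this file is that kernels now receive typed pointers via HostTensor::get_data_ptr<T>() instead of static_cast on raw void*. A reduced sketch of the after-state with a stand-in tensor type (illustrative names, not the nGraph source):

    #include <cstddef>
    #include <vector>

    // Stand-in for runtime::HostTensor: owns raw storage, hands out typed views.
    class HostBuffer
    {
    public:
        explicit HostBuffer(size_t bytes) : m_data(bytes) {}
        template <typename T>
        T* get_data_ptr() { return reinterpret_cast<T*>(m_data.data()); }

    private:
        std::vector<char> m_data;
    };

    template <typename T>
    void abs_kernel(const T* arg, T* out, size_t count)
    {
        for (size_t i = 0; i < count; i++)
        {
            out[i] = arg[i] < T(0) ? -arg[i] : arg[i];
        }
    }

    // The call site asks the tensor for a typed pointer, so the cast lives in
    // one audited place instead of at every kernel invocation.
    void run_abs(HostBuffer& arg, HostBuffer& out, size_t count)
    {
        abs_kernel<float>(arg.get_data_ptr<float>(), out.get_data_ptr<float>(), count);
    }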
@@ -237,26 +260,29 @@ private:
        }
        case OP_TYPEID::AllReduce:
        {
-            reference::allreduce<T>(static_cast<T*>(const_cast<void*>(args[0])),
-                                    static_cast<T*>(out[0]),
-                                    node.get_input_element_type(0),
+            const ngraph::op::AllReduce* allreduce =
+                static_cast<const ngraph::op::AllReduce*>(&node);
+            reference::allreduce<T>(args[0]->get_data_ptr<T>(),
+                                    out[0]->get_data_ptr<T>(),
+                                    node.get_input_element_type(0).get_type_enum(),
+                                    allreduce->get_reduce_type(),
                                    static_cast<int>(shape_size(node.get_input_shape(0))));
            break;
        }
        case OP_TYPEID::And:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
-            reference::logical_and(static_cast<const T*>(args[0]),
-                                   static_cast<const T*>(args[1]),
-                                   static_cast<T*>(out[0]),
+            reference::logical_and(args[0]->get_data_ptr<const T>(),
+                                   args[1]->get_data_ptr<const T>(),
+                                   out[0]->get_data_ptr<T>(),
                                   element_count);
            break;
        }
        case OP_TYPEID::Any:
        {
            const op::Any* any = static_cast<const op::Any*>(&node);
-            reference::any(static_cast<const char*>(args[0]),
-                           static_cast<char*>(out[0]),
+            reference::any(args[0]->get_data_ptr<const char>(),
+                           out[0]->get_data_ptr<char>(),
                           node.get_input_shape(0),
                           node.get_output_shape(0),
                           any->get_reduction_axes());
@@ -268,16 +294,16 @@ private:
            auto element_type = node.get_output_element_type(0);
            if (element_type == element::i64)
            {
-                reference::argmin<T, int64_t>(static_cast<const T*>(args[0]),
-                                              static_cast<int64_t*>(out[0]),
+                reference::argmin<T, int64_t>(args[0]->get_data_ptr<const T>(),
+                                              out[0]->get_data_ptr<int64_t>(),
                                              node.get_input_shape(0),
                                              node.get_output_shape(0),
                                              argmin->get_reduction_axis());
            }
            else if (element_type == element::i32)
            {
-                reference::argmin<T, int32_t>(static_cast<const T*>(args[0]),
-                                              static_cast<int32_t*>(out[0]),
+                reference::argmin<T, int32_t>(args[0]->get_data_ptr<const T>(),
+                                              out[0]->get_data_ptr<int32_t>(),
                                              node.get_input_shape(0),
                                              node.get_output_shape(0),
                                              argmin->get_reduction_axis());
@@ -294,16 +320,16 @@ private:
            auto element_type = node.get_output_element_type(0);
            if (element_type == element::i64)
            {
-                reference::argmax<T, int64_t>(static_cast<const T*>(args[0]),
-                                              static_cast<int64_t*>(out[0]),
+                reference::argmax<T, int64_t>(args[0]->get_data_ptr<const T>(),
+                                              out[0]->get_data_ptr<int64_t>(),
                                              node.get_input_shape(0),
                                              node.get_output_shape(0),
                                              argmax->get_reduction_axis());
            }
            else if (element_type == element::i32)
            {
-                reference::argmax<T, int32_t>(static_cast<const T*>(args[0]),
-                                              static_cast<int32_t*>(out[0]),
+                reference::argmax<T, int32_t>(args[0]->get_data_ptr<const T>(),
+                                              out[0]->get_data_ptr<int32_t>(),
                                              node.get_input_shape(0),
                                              node.get_output_shape(0),
                                              argmax->get_reduction_axis());
@@ -318,22 +344,22 @@ private:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
            reference::asin<T>(
-                static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
+                args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
            break;
        }
        case OP_TYPEID::Atan:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
            reference::atan<T>(
-                static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
+                args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
            break;
        }
        case OP_TYPEID::AvgPool:
        {
            const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
-            reference::avg_pool<T>(static_cast<const T*>(args[0]),
-                                   static_cast<T*>(out[0]),
+            reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
+                                   out[0]->get_data_ptr<T>(),
                                   node.get_input_shape(0),
                                   node.get_output_shape(0),
                                   avg_pool->get_window_shape(),
@@ -345,18 +371,30 @@ private:
        }
        case OP_TYPEID::GenerateMask:
        {
+            bool use_seed = static_cast<bool>(args[2]->get_data_ptr<const int32_t>()[0]);
            if (m_states.count(&node) == 0)
            {
                const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
+                auto seed = use_seed ? gm->get_seed() : 0;
                m_states[&node] = std::unique_ptr<ngraph::RNGState>(
-                    ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability()));
+                    ngraph::RNGState::create_rng_state(seed, gm->get_probability()));
            }
-            bool training = static_cast<bool>(static_cast<const T*>(args[0])[0]);
+            bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]);
            auto state = m_states.at(&node).get();
            size_t element_count = shape_size(node.get_output_shape(0));
+            if (!use_seed)
+            {
                reference::generate_mask<T>(
-                    reinterpret_cast<T*>(out[0]), element_count, state, training);
+                    out[0]->get_data_ptr<T>(), element_count, state, training);
+            }
+            else
+            {
+                uint64_t seed = static_cast<uint64_t>(args[3]->get_data_ptr<const T>()[0]);
+                double prob = static_cast<double>(args[4]->get_data_ptr<const T>()[0]);
+                reference::generate_mask_no_state<T>(
+                    out[0]->get_data_ptr<T>(), element_count, training, seed, prob);
+            }
            break;
        }
        case OP_TYPEID::GetOutputElement:
@@ -366,20 +404,31 @@ private:
            size_t n = get_output_element->get_n();
            size_t element_count = shape_size(node.get_output_shape(0));
            size_t num_bytes = element_count * node.get_output_element_type(0).size();
-            std::memcpy(static_cast<T*>(out[0]), args[n], num_bytes);
+            std::memcpy(out[0]->get_data_ptr<T>(), args[n]->get_data_ptr<T>(), num_bytes);
+            break;
+        }
+        case OP_TYPEID::BatchMatMul:
+        {
+            reference::batch_mat_mul(args[0]->get_data_ptr<const T>(),
+                                     args[1]->get_data_ptr<const T>(),
+                                     out[0]->get_data_ptr<T>(),
+                                     node.get_input_shape(0),
+                                     node.get_input_shape(1),
+                                     node.get_output_shape(0));
            break;
        }
        case OP_TYPEID::BatchNormTraining:
        {
            const ngraph::op::BatchNormTraining* bn =
                static_cast<const ngraph::op::BatchNormTraining*>(&node);
            reference::batch_norm_training<T>(bn->get_eps_value(),
-                                              static_cast<const T*>(args[0]),
-                                              static_cast<const T*>(args[1]),
-                                              static_cast<const T*>(args[2]),
-                                              static_cast<T*>(out[0]),
-                                              static_cast<T*>(out[1]),
-                                              static_cast<T*>(out[2]),
+                                              args[0]->get_data_ptr<const T>(),
+                                              args[1]->get_data_ptr<const T>(),
+                                              args[2]->get_data_ptr<const T>(),
+                                              out[0]->get_data_ptr<T>(),
+                                              out[1]->get_data_ptr<T>(),
+                                              out[2]->get_data_ptr<T>(),
                                              node.get_input_shape(2));
            break;
        }
@@ -388,12 +437,12 @@ private:
            const ngraph::op::BatchNormInference* bn =
                static_cast<const ngraph::op::BatchNormInference*>(&node);
            reference::batch_norm_inference<T>(bn->get_eps_value(),
-                                               static_cast<const T*>(args[0]),
-                                               static_cast<const T*>(args[1]),
-                                               static_cast<const T*>(args[2]),
-                                               static_cast<const T*>(args[3]),
-                                               static_cast<const T*>(args[4]),
-                                               static_cast<T*>(out[0]),
+                                               args[0]->get_data_ptr<const T>(),
+                                               args[1]->get_data_ptr<const T>(),
+                                               args[2]->get_data_ptr<const T>(),
+                                               args[3]->get_data_ptr<const T>(),
+                                               args[4]->get_data_ptr<const T>(),
+                                               out[0]->get_data_ptr<T>(),
                                               node.get_input_shape(2));
            break;
        }
@@ -402,23 +451,23 @@ private:
            const ngraph::op::BatchNormTrainingBackprop* bn_bprop =
                static_cast<const ngraph::op::BatchNormTrainingBackprop*>(&node);
            reference::batch_norm_backprop(bn_bprop->get_eps_value(),
-                                           static_cast<const T*>(args[0]),
-                                           static_cast<const T*>(args[1]),
-                                           static_cast<const T*>(args[2]),
-                                           static_cast<const T*>(args[3]),
-                                           static_cast<const T*>(args[4]),
-                                           static_cast<const T*>(args[5]),
-                                           static_cast<T*>(out[0]),
-                                           static_cast<T*>(out[1]),
-                                           static_cast<T*>(out[2]),
+                                           args[0]->get_data_ptr<const T>(),
+                                           args[1]->get_data_ptr<const T>(),
+                                           args[2]->get_data_ptr<const T>(),
+                                           args[3]->get_data_ptr<const T>(),
+                                           args[4]->get_data_ptr<const T>(),
+                                           args[5]->get_data_ptr<const T>(),
+                                           out[0]->get_data_ptr<T>(),
+                                           out[1]->get_data_ptr<T>(),
+                                           out[2]->get_data_ptr<T>(),
                                           node.get_input_shape(2));
            break;
        }
        case OP_TYPEID::AvgPoolBackprop:
        {
            const op::AvgPoolBackprop* apb = static_cast<const op::AvgPoolBackprop*>(&node);
-            reference::avg_pool_backprop<T>(static_cast<const T*>(args[0]),
-                                            static_cast<T*>(out[0]),
+            reference::avg_pool_backprop<T>(args[0]->get_data_ptr<const T>(),
+                                            out[0]->get_data_ptr<T>(),
                                            node.get_input_shape(0),
                                            node.get_output_shape(0),
                                            apb->get_window_shape(),
@@ -434,8 +483,8 @@ private:
            Shape in_shape = node.get_input_shape(0);
            Shape out_shape = node.get_output_shape(0);
            AxisSet broadcast_axes = broadcast->get_broadcast_axes();
-            gcpu::kernel::broadcast<T>(static_cast<const T*>(args[0]),
-                                       static_cast<T*>(out[0]),
+            kernel::broadcast<T>(args[0]->get_data_ptr<const T>(),
+                                 out[0]->get_data_ptr<T>(),
                                 in_shape,
                                 out_shape,
                                 broadcast_axes);
@@ -443,23 +492,28 @@ private:
        }
        case OP_TYPEID::BroadcastDistributed:
        {
-            int rank_ID = get_distributed_interface()->get_rank();
-            if (rank_ID == 0)
+            const ngraph::op::BroadcastDistributed* broadcast =
+                static_cast<const ngraph::op::BroadcastDistributed*>(&node);
+            int rank_ID;
+            rank_ID = get_distributed_interface()->get_rank();
+            int root_id = broadcast->get_root_id();
+            if (rank_ID == root_id)
            {
                reference::broadcastdistributed<T>(
-                    static_cast<T*>(args[0]),
-                    node.get_input_element_type(0),
-                    static_cast<int>(shape_size(node.get_input_shape(0))));
-                auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) *
-                    sizeof(node.get_input_element_type(0));
-                memcpy(out[0], args[0], memSize);
+                    args[0]->get_data_ptr<T>(),
+                    node.get_input_element_type(0).get_type_enum(),
+                    static_cast<int>(shape_size(node.get_input_shape(0))),
+                    root_id);
+                auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) * sizeof(T);
+                memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
            }
            else
            {
                reference::broadcastdistributed<T>(
-                    static_cast<T*>(out[0]),
-                    node.get_input_element_type(0),
-                    static_cast<int>(shape_size(node.get_input_shape(0))));
+                    out[0]->get_data_ptr<T>(),
+                    node.get_input_element_type(0).get_type_enum(),
+                    static_cast<int>(shape_size(node.get_input_shape(0))),
+                    root_id);
            }
            break;
        }
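Review note: this hunk makes the broadcast root configurable through get_root_id() instead of hard-coding rank 0. For intuition, the same root-vs-non-root semantics in plain MPI (a sketch of the collective's behavior, not nGraph's distributed interface):

    #include <mpi.h>

    // Every rank calls MPI_Bcast with the same root; the root's buffer is
    // sent, and all other ranks' buffers are overwritten in place.
    void broadcast_f32(float* buffer, int count, int root_id)
    {
        MPI_Bcast(buffer, count, MPI_FLOAT, root_id, MPI_COMM_WORLD);
    }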
@@ -468,7 +522,7 @@ private:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
            reference::ceiling<T>(
-                static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
+                args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
            break;
        }
        case OP_TYPEID::Concat:
@@ -478,11 +532,11 @@ private:
            std::vector<Shape> in_shapes;
            for (size_t i = 0; i < node.get_input_size(); i++)
            {
-                in_args.push_back(static_cast<const T*>(args[i]));
+                in_args.push_back(args[i]->get_data_ptr<const T>());
                in_shapes.push_back(node.get_input_shape(i));
            }
            reference::concat<T>(in_args,
-                                 static_cast<T*>(out[0]),
+                                 out[0]->get_data_ptr<T>(),
                                 in_shapes,
                                 node.get_output_shape(0),
                                 concat->get_concatenation_axis());
@@ -492,7 +546,7 @@ private:
        {
            const op::Constant* c = static_cast<const op::Constant*>(&node);
            size_t element_count = shape_size(node.get_output_shape(0));
-            reference::constant<T>(c->get_data_ptr<T>(), static_cast<T*>(out[0]), element_count);
+            reference::constant<T>(c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), element_count);
            break;
        }
        case OP_TYPEID::ScalarConstantLike: break;
@@ -505,52 +559,62 @@ private:
            switch (type.get_type_enum())
            {
            case element::Type_t::boolean:
-                reference::convert<T>(
-                    static_cast<const T*>(args[0]), static_cast<char*>(out[0]), element_count);
+                reference::convert_to_bool<T>(
+                    args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
                break;
            case element::Type_t::f32:
                reference::convert<T>(
-                    static_cast<const T*>(args[0]), static_cast<float*>(out[0]), element_count);
+                    args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
                break;
            case element::Type_t::f64:
-                reference::convert<T>(
-                    static_cast<const T*>(args[0]), static_cast<double*>(out[0]), element_count);
+                reference::convert<T>(args[0]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<double>(),
+                                      element_count);
                break;
            case element::Type_t::i8:
-                reference::convert<T>(
-                    static_cast<const T*>(args[0]), static_cast<int8_t*>(out[0]), element_count);
+                reference::convert<T>(args[0]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<int8_t>(),
+                                      element_count);
                break;
            case element::Type_t::i16:
-                reference::convert<T>(
-                    static_cast<const T*>(args[0]), static_cast<int16_t*>(out[0]), element_count);
+                reference::convert<T>(args[0]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<int16_t>(),
+                                      element_count);
                break;
            case element::Type_t::i32:
-                reference::convert<T>(
-                    static_cast<const T*>(args[0]), static_cast<int32_t*>(out[0]), element_count);
+                reference::convert<T>(args[0]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<int32_t>(),
+                                      element_count);
                break;
            case element::Type_t::i64:
-                reference::convert<T>(
-                    static_cast<const T*>(args[0]), static_cast<int64_t*>(out[0]), element_count);
+                reference::convert<T>(args[0]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<int64_t>(),
+                                      element_count);
                break;
            case element::Type_t::u8:
-                reference::convert<T>(
-                    static_cast<const T*>(args[0]), static_cast<uint8_t*>(out[0]), element_count);
+                reference::convert<T>(args[0]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<uint8_t>(),
+                                      element_count);
                break;
            case element::Type_t::u16:
-                reference::convert<T>(
-                    static_cast<const T*>(args[0]), static_cast<uint16_t*>(out[0]), element_count);
+                reference::convert<T>(args[0]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<uint16_t>(),
+                                      element_count);
                break;
            case element::Type_t::u32:
-                reference::convert<T>(
-                    static_cast<const T*>(args[0]), static_cast<uint32_t*>(out[0]), element_count);
+                reference::convert<T>(args[0]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<uint32_t>(),
+                                      element_count);
                break;
            case element::Type_t::u64:
-                reference::convert<T>(
-                    static_cast<const T*>(args[0]), static_cast<uint64_t*>(out[0]), element_count);
+                reference::convert<T>(args[0]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<uint64_t>(),
+                                      element_count);
                break;
            case element::Type_t::undefined:
            case element::Type_t::dynamic:
            case element::Type_t::bf16:
+            case element::Type_t::f16:
                ss << "unsupported element type " << type << " op Convert";
                throw std::runtime_error(ss.str());
            }
@@ -559,9 +623,9 @@ private:
        case OP_TYPEID::Convolution:
        {
            const op::Convolution* c = static_cast<const op::Convolution*>(&node);
-            reference::convolution<T>(static_cast<const T*>(args[0]),
-                                      static_cast<const T*>(args[1]),
-                                      static_cast<T*>(out[0]),
+            reference::convolution<T>(args[0]->get_data_ptr<const T>(),
+                                      args[1]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<T>(),
                                      node.get_input_shape(0),
                                      node.get_input_shape(1),
                                      node.get_output_shape(0),
@@ -569,38 +633,26 @@ private:
                                      c->get_window_dilation_strides(),
                                      c->get_padding_below(),
                                      c->get_padding_above(),
-                                      c->get_data_dilation_strides(),
-                                      0,
-                                      1,
-                                      1,
-                                      0,
-                                      0,
-                                      1,
-                                      false);
+                                      c->get_data_dilation_strides());
            break;
        }
        case OP_TYPEID::ConvolutionBackpropFilters:
        {
            const op::ConvolutionBackpropFilters* c =
                static_cast<const op::ConvolutionBackpropFilters*>(&node);
-            reference::convolution<T>(static_cast<const T*>(args[0]),
-                                      static_cast<const T*>(args[1]),
-                                      static_cast<T*>(out[0]),
-                                      node.get_input_shape(0),
-                                      node.get_input_shape(1),
-                                      node.get_output_shape(0),
-                                      c->get_window_movement_strides_backward(),
-                                      c->get_window_dilation_strides_backward(),
-                                      c->get_padding_below_backward(),
-                                      c->get_padding_above_backward(),
-                                      c->get_data_dilation_strides_backward(),
-                                      1,
-                                      0,
-                                      0,
-                                      1,
-                                      1,
-                                      0,
-                                      false);
+            reference::convolution_backprop_filter<T>(
+                args[0]->get_data_ptr<const T>(), // input
+                args[1]->get_data_ptr<const T>(), // delta_convolution_output
+                out[0]->get_data_ptr<T>(),        // delta_filter
+                c->get_input_shape(0),            // input_shape
+                c->get_input_shape(1),            // convolution_output_shape
+                c->get_filters_shape(),           // filter_shape
+                c->get_window_dilation_strides_forward(),
+                c->get_window_movement_strides_forward(),
+                c->get_padding_below_forward(),
+                c->compute_backward_in_pad_above(),
+                c->get_data_dilation_strides_forward());
            break;
        }
        case OP_TYPEID::ConvolutionBackpropData:
@@ -608,38 +660,31 @@ private:
            // Note that args[1] and args[0] are switched here from the usual order.
            const op::ConvolutionBackpropData* c =
                static_cast<const op::ConvolutionBackpropData*>(&node);
-            reference::convolution<T>(static_cast<const T*>(args[1]),
-                                      static_cast<const T*>(args[0]),
-                                      static_cast<T*>(out[0]),
-                                      node.get_input_shape(1),
-                                      node.get_input_shape(0),
-                                      node.get_output_shape(0),
-                                      c->get_window_movement_strides_backward(),
-                                      c->get_window_dilation_strides_backward(),
-                                      c->get_padding_below_backward(),
-                                      c->get_padding_above_backward(),
-                                      c->get_data_dilation_strides_backward(),
-                                      0,
-                                      1,
-                                      0,
-                                      1,
-                                      0,
-                                      1,
-                                      true);
+            reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
+                                                  args[0]->get_data_ptr<const T>(),
+                                                  out[0]->get_data_ptr<T>(),
+                                                  c->get_input_shape(1),
+                                                  c->get_input_shape(0),
+                                                  c->get_data_batch_shape(),
+                                                  c->get_data_dilation_strides_forward(),
+                                                  c->get_window_dilation_strides_forward(),
+                                                  c->compute_backward_delta_out_pad_below(),
+                                                  c->compute_backward_delta_out_pad_above(),
+                                                  c->get_window_movement_strides_forward());
            break;
        }
        case OP_TYPEID::Cos:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
            reference::cos<T>(
-                static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
+                args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
            break;
        }
        case OP_TYPEID::Cosh:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
            reference::cosh<T>(
-                static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
+                args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
            break;
        }
        case OP_TYPEID::Dequantize:
@@ -649,20 +694,20 @@ private:
            if (type == element::f32)
            {
-                reference::dequantize<T>(static_cast<const T*>(args[0]),
-                                         static_cast<const float*>(args[1]),
-                                         static_cast<const T*>(args[2]),
-                                         static_cast<float*>(out[0]),
+                reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
+                                         args[1]->get_data_ptr<const float>(),
+                                         args[2]->get_data_ptr<const T>(),
+                                         out[0]->get_data_ptr<float>(),
                                         node.get_input_shape(0),
                                         node.get_input_shape(1),
                                         dequantize->get_axes());
            }
            else if (type == element::f64)
            {
-                reference::dequantize<T>(static_cast<const T*>(args[0]),
-                                         static_cast<const double*>(args[1]),
-                                         static_cast<const T*>(args[2]),
-                                         static_cast<double*>(out[0]),
+                reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
+                                         args[1]->get_data_ptr<const double>(),
+                                         args[2]->get_data_ptr<const T>(),
+                                         out[0]->get_data_ptr<double>(),
                                         node.get_input_shape(0),
                                         node.get_input_shape(1),
                                         dequantize->get_axes());
@@ -680,9 +725,9 @@ private:
        {
            const op::Divide* divop = static_cast<const op::Divide*>(&node);
            size_t element_count = shape_size(node.get_output_shape(0));
-            reference::divide<T>(static_cast<const T*>(args[0]),
-                                 static_cast<const T*>(args[1]),
-                                 static_cast<T*>(out[0]),
+            reference::divide<T>(args[0]->get_data_ptr<const T>(),
+                                 args[1]->get_data_ptr<const T>(),
+                                 out[0]->get_data_ptr<T>(),
                                 element_count,
                                 divop->is_pythondiv());
            break;
@@ -691,15 +736,25 @@ private:
        {
            const op::Dot* dot = static_cast<const op::Dot*>(&node);
-            gcpu::kernel::dot(static_cast<const T*>(args[0]),
-                              static_cast<const T*>(args[1]),
-                              static_cast<T*>(out[0]),
+            kernel::dot(args[0]->get_data_ptr<const T>(),
+                        args[1]->get_data_ptr<const T>(),
+                        out[0]->get_data_ptr<T>(),
                        node.get_input_shape(0),
                        node.get_input_shape(1),
                        node.get_output_shape(0),
                        dot->get_reduction_axes_count());
            break;
        }
+        case OP_TYPEID::DynReshape:
+        {
+            throw unsupported_op("Unsupported op '" + node.description() + "'");
+            break;
+        }
+        case OP_TYPEID::DynSlice:
+        {
+            throw unsupported_op("Unsupported op '" + node.description() + "'");
+            break;
+        }
        case OP_TYPEID::EmbeddingLookup:
        {
            const op::EmbeddingLookup* embed = static_cast<const op::EmbeddingLookup*>(&node);
@@ -708,33 +763,33 @@ private:
            if (type == element::f32)
            {
-                reference::embedding<T, float>(static_cast<const float*>(args[0]),
-                                               static_cast<const T*>(args[1]),
-                                               static_cast<T*>(out[0]),
+                reference::embedding<T, float>(args[0]->get_data_ptr<const float>(),
+                                               args[1]->get_data_ptr<const T>(),
+                                               out[0]->get_data_ptr<T>(),
                                               element_count,
                                               embed->get_shape());
            }
            else if (type == element::f64)
            {
-                reference::embedding<T, double>(static_cast<const double*>(args[0]),
-                                                static_cast<const T*>(args[1]),
-                                                static_cast<T*>(out[0]),
+                reference::embedding<T, double>(args[0]->get_data_ptr<const double>(),
+                                                args[1]->get_data_ptr<const T>(),
+                                                out[0]->get_data_ptr<T>(),
                                                element_count,
                                                embed->get_shape());
            }
            else if (type == element::i32)
            {
-                reference::embedding<T, int>(static_cast<const int*>(args[0]),
-                                             static_cast<const T*>(args[1]),
-                                             static_cast<T*>(out[0]),
+                reference::embedding<T, int32_t>(args[0]->get_data_ptr<const int>(),
+                                                 args[1]->get_data_ptr<const T>(),
+                                                 out[0]->get_data_ptr<T>(),
                                                 element_count,
                                                 embed->get_shape());
            }
            else if (type == element::i64)
            {
-                reference::embedding<T, int64_t>(static_cast<const int64_t*>(args[0]),
-                                                 static_cast<const T*>(args[1]),
-                                                 static_cast<T*>(out[0]),
+                reference::embedding<T, int64_t>(args[0]->get_data_ptr<const int64_t>(),
+                                                 args[1]->get_data_ptr<const T>(),
+                                                 out[0]->get_data_ptr<T>(),
                                                 element_count,
                                                 embed->get_shape());
            }
@@ -748,24 +803,56 @@ private:
        case OP_TYPEID::Equal:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
-            reference::equal<T>(static_cast<const T*>(args[0]),
-                                static_cast<const T*>(args[1]),
-                                static_cast<char*>(out[0]),
+            reference::equal<T>(args[0]->get_data_ptr<const T>(),
+                                args[1]->get_data_ptr<const T>(),
+                                out[0]->get_data_ptr<char>(),
                                element_count);
            break;
        }
+        case OP_TYPEID::Erf:
+        {
+            size_t element_count = shape_size(node.get_output_shape(0));
+            reference::erf<T>(
+                args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
+            break;
+        }
        case OP_TYPEID::Exp:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
            reference::exp<T>(
-                static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
+                args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
            break;
        }
+#ifdef INTERPRETER_USE_HYBRID
+        case OP_TYPEID::FunctionCall:
+        {
+            auto f = static_cast<const runtime::hybrid::op::FunctionCall*>(&node);
+            auto backend = f->get_backend();
+            auto executable = f->get_executable();
+
+            std::vector<std::shared_ptr<Tensor>> outputs;
+            std::vector<std::shared_ptr<Tensor>> inputs;
+            for (const std::shared_ptr<HostTensor>& t : out)
+            {
+                auto backend_tensor = backend->create_tensor(
+                    t->get_element_type(), t->get_shape(), t->get_data_ptr());
+                outputs.push_back(backend_tensor);
+            }
+            for (const std::shared_ptr<HostTensor>& t : args)
+            {
+                auto backend_tensor = backend->create_tensor(
+                    t->get_element_type(), t->get_shape(), t->get_data_ptr());
+                inputs.push_back(backend_tensor);
+            }
+            executable->call(outputs, inputs);
+            break;
+        }
+#endif
        case OP_TYPEID::Floor:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
            reference::floor<T>(
-                static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
+                args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
            break;
        }
        case OP_TYPEID::Gather:
@@ -826,36 +913,36 @@ private:
        case OP_TYPEID::Greater:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
-            reference::greater<T>(static_cast<const T*>(args[0]),
-                                  static_cast<const T*>(args[1]),
-                                  static_cast<char*>(out[0]),
+            reference::greater<T>(args[0]->get_data_ptr<const T>(),
+                                  args[1]->get_data_ptr<const T>(),
+                                  out[0]->get_data_ptr<char>(),
                                  element_count);
            break;
        }
        case OP_TYPEID::GreaterEq:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
-            reference::greater_eq<T>(static_cast<const T*>(args[0]),
-                                     static_cast<const T*>(args[1]),
-                                     static_cast<char*>(out[0]),
+            reference::greater_eq<T>(args[0]->get_data_ptr<const T>(),
+                                     args[1]->get_data_ptr<const T>(),
+                                     out[0]->get_data_ptr<char>(),
                                     element_count);
            break;
        }
        case OP_TYPEID::Less:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
-            reference::less<T>(static_cast<const T*>(args[0]),
-                               static_cast<const T*>(args[1]),
-                               static_cast<char*>(out[0]),
+            reference::less<T>(args[0]->get_data_ptr<const T>(),
+                               args[1]->get_data_ptr<const T>(),
+                               out[0]->get_data_ptr<char>(),
                               element_count);
            break;
        }
        case OP_TYPEID::LessEq:
        {
            size_t element_count = shape_size(node.get_output_shape(0));
-            reference::less_eq<T>(static_cast<const T*>(args[0]),
-                                  static_cast<const T*>(args[1]),
-                                  static_cast<char*>(out[0]),
+            reference::less_eq<T>(args[0]->get_data_ptr<const T>(),
+                                  args[1]->get_data_ptr<const T>(),
+                                  out[0]->get_data_ptr<char>(),
                                  element_count);
            break;
        }
...@@ -863,14 +950,14 @@ private: ...@@ -863,14 +950,14 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::log<T>( reference::log<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::LRN: case OP_TYPEID::LRN:
{ {
const op::LRN* lrn = static_cast<const op::LRN*>(&node); const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(static_cast<const T*>(args[0]), reference::lrn<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
lrn->get_alpha(), lrn->get_alpha(),
lrn->get_beta(), lrn->get_beta(),
...@@ -881,8 +968,8 @@ private: ...@@ -881,8 +968,8 @@ private:
case OP_TYPEID::Max: case OP_TYPEID::Max:
{ {
const op::Max* max = static_cast<const op::Max*>(&node); const op::Max* max = static_cast<const op::Max*>(&node);
reference::max<T>(static_cast<const T*>(args[0]), reference::max<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
max->get_reduction_axes()); max->get_reduction_axes());
...@@ -891,9 +978,9 @@ private: ...@@ -891,9 +978,9 @@ private:
case OP_TYPEID::Maximum: case OP_TYPEID::Maximum:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::maximum<T>(static_cast<const T*>(args[0]), reference::maximum<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
...@@ -901,8 +988,8 @@ private: ...@@ -901,8 +988,8 @@ private:
{ {
const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node); const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node);
reference::max_pool<T>(static_cast<const T*>(args[0]), reference::max_pool<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
max_pool->get_window_shape(), max_pool->get_window_shape(),
...@@ -916,9 +1003,9 @@ private: ...@@ -916,9 +1003,9 @@ private:
const op::MaxPoolBackprop* max_pool_backprop = const op::MaxPoolBackprop* max_pool_backprop =
static_cast<const op::MaxPoolBackprop*>(&node); static_cast<const op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(static_cast<const T*>(args[0]), reference::max_pool_backprop<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(1), node.get_input_shape(1),
node.get_output_shape(0), node.get_output_shape(0),
max_pool_backprop->get_window_shape(), max_pool_backprop->get_window_shape(),
...@@ -930,8 +1017,8 @@ private: ...@@ -930,8 +1017,8 @@ private:
case OP_TYPEID::Min: case OP_TYPEID::Min:
{ {
const op::Min* min = static_cast<const op::Min*>(&node); const op::Min* min = static_cast<const op::Min*>(&node);
reference::min<T>(static_cast<const T*>(args[0]), reference::min<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
min->get_reduction_axes()); min->get_reduction_axes());
...@@ -940,18 +1027,18 @@ private: ...@@ -940,18 +1027,18 @@ private:
case OP_TYPEID::Minimum: case OP_TYPEID::Minimum:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::minimum<T>(static_cast<const T*>(args[0]), reference::minimum<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Multiply: case OP_TYPEID::Multiply:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::multiply<T>(static_cast<const T*>(args[0]), reference::multiply<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
...@@ -959,30 +1046,30 @@ private: ...@@ -959,30 +1046,30 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::negate<T>( reference::negate<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Not: case OP_TYPEID::Not:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_not( reference::logical_not(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::NotEqual: case OP_TYPEID::NotEqual:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::not_equal<T>(static_cast<const T*>(args[0]), reference::not_equal<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::OneHot: case OP_TYPEID::OneHot:
{ {
const op::OneHot* oh = static_cast<const op::OneHot*>(&node); const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(static_cast<const T*>(args[0]), reference::one_hot<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
oh->get_one_hot_axis()); oh->get_one_hot_axis());
...@@ -991,46 +1078,46 @@ private: ...@@ -991,46 +1078,46 @@ private:
case OP_TYPEID::Or: case OP_TYPEID::Or:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_or(static_cast<const T*>(args[0]), reference::logical_or(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Parameter: break; case OP_TYPEID::Parameter: break;
case OP_TYPEID::Passthrough:
{
const op::Passthrough* passthrough = static_cast<const op::Passthrough*>(&node);
throw unsupported_op{"Unsupported operation language: " + passthrough->language()};
}
case OP_TYPEID::Pad: case OP_TYPEID::Pad:
{ {
const op::Pad* pad = static_cast<const op::Pad*>(&node); const op::Pad* pad = static_cast<const op::Pad*>(&node);
reference::pad(static_cast<const T*>(args[0]), reference::pad(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_inputs().at(0).get_shape(), node.input(0).get_shape(),
node.get_output_shape(0), node.output(0).get_shape(),
pad->get_padding_below(), pad->get_padding_below(),
pad->get_padding_above(), pad->get_padding_above(),
pad->get_padding_interior()); pad->get_pad_mode());
break; break;
} }
case OP_TYPEID::Passthrough:
{
const op::Passthrough* passthrough = static_cast<const op::Passthrough*>(&node);
throw unsupported_op{"Unsupported operation language: " + passthrough->language()};
}
case OP_TYPEID::Power: case OP_TYPEID::Power:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::power<T>(static_cast<const T*>(args[0]), reference::power<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Product: case OP_TYPEID::Product:
{ {
const op::Product* product = static_cast<const op::Product*>(&node); const op::Product* product = static_cast<const op::Product*>(&node);
reference::product<T>(static_cast<const T*>(args[0]), reference::product<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
product->get_reduction_axes()); product->get_reduction_axes());
...@@ -1043,10 +1130,10 @@ private: ...@@ -1043,10 +1130,10 @@ private:
if (type == element::u8) if (type == element::u8)
{ {
reference::quantize<T>(static_cast<const T*>(args[0]), reference::quantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const uint8_t*>(args[2]), args[2]->get_data_ptr<const uint8_t>(),
static_cast<uint8_t*>(out[0]), out[0]->get_data_ptr<uint8_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
quantize->get_axes(), quantize->get_axes(),
...@@ -1054,10 +1141,10 @@ private: ...@@ -1054,10 +1141,10 @@ private:
} }
else if (type == element::i8) else if (type == element::i8)
{ {
reference::quantize<T>(static_cast<const T*>(args[0]), reference::quantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const int8_t*>(args[2]), args[2]->get_data_ptr<const int8_t>(),
static_cast<int8_t*>(out[0]), out[0]->get_data_ptr<int8_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
quantize->get_axes(), quantize->get_axes(),
...@@ -1065,10 +1152,10 @@ private: ...@@ -1065,10 +1152,10 @@ private:
} }
else if (type == element::i32) else if (type == element::i32)
{ {
reference::quantize<T>(static_cast<const T*>(args[0]), reference::quantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const int32_t*>(args[2]), args[2]->get_data_ptr<const int32_t>(),
static_cast<int32_t*>(out[0]), out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
quantize->get_axes(), quantize->get_axes(),
...@@ -1083,40 +1170,168 @@ private: ...@@ -1083,40 +1170,168 @@ private:
break; break;
} }
case OP_TYPEID::QuantizedConvolution:
{
const op::QuantizedConvolution* qc =
static_cast<const op::QuantizedConvolution*>(&node);
auto input_element_type = qc->get_input_element_type(0);
auto filter_element_type = qc->get_input_element_type(1);
auto output_element_type = qc->get_output_element_type(0);
if (input_element_type == element::u8 && filter_element_type == element::i8 &&
output_element_type == element::i8)
{
reference::convolution<uint8_t, int8_t, int8_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const int8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int8_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
output_element_type == element::u8)
{
reference::convolution<uint8_t, uint8_t, uint8_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<uint8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const uint8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const uint8_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::i8 &&
output_element_type == element::i32)
{
reference::convolution<uint8_t, int8_t, int32_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const int8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int32_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
output_element_type == element::i32)
{
reference::convolution<uint8_t, uint8_t, int32_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const uint8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int32_t>());
}
else
{
std::stringstream ss;
ss << "unsupported element type";
throw std::runtime_error(ss.str());
}
break;
}
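The new QuantizedConvolution case dispatches on the (input, filter, output) element-type triple; the accumulator is int32_t in every branch, and args[2] through args[7] appear to carry the scale/zero-point pairs for input, filter, and output, judging by their types. The rule the if/else ladder encodes, as a small sketch with illustrative names:

enum class Et { u8, i8, i32 };

// Input is always u8; the filter may be i8 or u8; the output is either the
// filter type or i32. Anything else lands in the runtime_error branch.
bool supported_qconv(Et in, Et filt, Et out)
{
    if (in != Et::u8) return false;
    if (filt == Et::i8) return out == Et::i8 || out == Et::i32;
    if (filt == Et::u8) return out == Et::u8 || out == Et::i32;
    return false;
}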
case OP_TYPEID::QuantizedAvgPool: case OP_TYPEID::QuantizedAvgPool:
case OP_TYPEID::QuantizedConvolutionBias: case OP_TYPEID::QuantizedConvolutionBias:
case OP_TYPEID::QuantizedConvolutionBiasAdd: case OP_TYPEID::QuantizedConvolutionBiasAdd:
case OP_TYPEID::QuantizedConvolutionBiasSignedAdd: case OP_TYPEID::QuantizedConvolutionBiasSignedAdd:
case OP_TYPEID::QuantizedConvolutionRelu: case OP_TYPEID::QuantizedConvolutionRelu:
case OP_TYPEID::QuantizedConvolution:
case OP_TYPEID::QuantizedMaxPool: case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::QuantizedDotBias: case OP_TYPEID::QuantizedDotBias:
case OP_TYPEID::QuantizedDot: case OP_TYPEID::QuantizedDot:
{ {
throw unsupported_op("Unsupported op '" + node.description() + "'."); throw unsupported_op("Unsupported op '" + node.description() +
"' in Interpreter back end.");
}
case OP_TYPEID::Recv:
{
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Recv*>(&node);
int src_id = op->get_src_id();
reference::recv<T>(args[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
src_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::Range:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
} }
case OP_TYPEID::Relu: case OP_TYPEID::Relu:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::relu<T>( reference::relu<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::ReluBackprop: case OP_TYPEID::ReluBackprop:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::relu_backprop<T>(static_cast<const T*>(args[0]), reference::relu_backprop<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::ReplaceSlice: case OP_TYPEID::ReplaceSlice:
{ {
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node); const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(static_cast<const T*>(args[0]), reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(1), node.get_input_shape(1),
slice->get_lower_bounds(), slice->get_lower_bounds(),
slice->get_upper_bounds(), slice->get_upper_bounds(),
...@@ -1127,8 +1342,8 @@ private: ...@@ -1127,8 +1342,8 @@ private:
case OP_TYPEID::Reshape: case OP_TYPEID::Reshape:
{ {
const op::Reshape* reshape = static_cast<const op::Reshape*>(&node); const op::Reshape* reshape = static_cast<const op::Reshape*>(&node);
gcpu::kernel::reshape(static_cast<const T*>(args[0]), kernel::reshape(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
reshape->get_input_order(), reshape->get_input_order(),
node.get_output_shape(0)); node.get_output_shape(0));
...@@ -1137,16 +1352,16 @@ private: ...@@ -1137,16 +1352,16 @@ private:
case OP_TYPEID::Result: case OP_TYPEID::Result:
{ {
const op::Result* res = static_cast<const op::Result*>(&node); const op::Result* res = static_cast<const op::Result*>(&node);
reference::result(static_cast<const T*>(args[0]), reference::result(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
shape_size(res->get_shape())); shape_size(res->get_shape()));
break; break;
} }
case OP_TYPEID::Reverse: case OP_TYPEID::Reverse:
{ {
const op::Reverse* reverse = static_cast<const op::Reverse*>(&node); const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
reference::reverse(static_cast<const T*>(args[0]), reference::reverse(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
reverse->get_reversed_axes()); reverse->get_reversed_axes());
...@@ -1158,12 +1373,12 @@ private: ...@@ -1158,12 +1373,12 @@ private:
if (node.get_input_element_type(1) == element::i32) if (node.get_input_element_type(1) == element::i32)
{ {
reference::reverse_sequence<T, int32_t>(static_cast<const T*>(args[0]), reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
reverse->get_batch_axis(), reverse->get_batch_axis(),
reverse->get_sequence_axis(), reverse->get_sequence_axis(),
static_cast<const int32_t*>(args[1])); args[1]->get_data_ptr<const int32_t>());
} }
else else
{ {
...@@ -1234,31 +1449,46 @@ private: ...@@ -1234,31 +1449,46 @@ private:
case OP_TYPEID::Select: case OP_TYPEID::Select:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::select<T>(static_cast<const char*>(args[0]), reference::select<T>(args[0]->get_data_ptr<const char>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Send:
{
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Send*>(&node);
int dest_id = op->get_dest_id();
reference::send<T>(args[0]->get_data_ptr<const T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
dest_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
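Send and Recv share one shape: run the communication primitive over args[0] (Recv writes into that buffer, hence its non-const get_data_ptr), then memcpy the buffer through to out[0], so each op behaves as an identity in the dataflow graph. What reference::send wraps is not shown in this diff; with a raw MPI transport it would reduce to roughly the following, which is an assumption, not ngraph's code:

#include <mpi.h>
#include <cstddef>

template <typename T>
void send_sketch(const T* buf, size_t element_count, int dest_id)
{
    // Byte-wise send; the tag and communicator choices are placeholders.
    MPI_Send(buf,
             static_cast<int>(element_count * sizeof(T)),
             MPI_BYTE,
             dest_id,
             /*tag=*/0,
             MPI_COMM_WORLD);
}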
case OP_TYPEID::ShapeOf: case OP_TYPEID::ShapeOf:
{ {
reference::shape_of(node.get_input_shape(0), static_cast<uint64_t*>(out[0])); reference::shape_of(node.get_input_shape(0), out[0]->get_data_ptr<uint64_t>());
break; break;
} }
case OP_TYPEID::Sigmoid: case OP_TYPEID::Sigmoid:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid<T>( reference::sigmoid<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::SigmoidBackprop: case OP_TYPEID::SigmoidBackprop:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid_backprop<T>(static_cast<const T*>(args[0]), reference::sigmoid_backprop<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
...@@ -1266,28 +1496,28 @@ private: ...@@ -1266,28 +1496,28 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sign<T>( reference::sign<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Sin: case OP_TYPEID::Sin:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sin<T>( reference::sin<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Sinh: case OP_TYPEID::Sinh:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sinh<T>( reference::sinh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Slice: case OP_TYPEID::Slice:
{ {
const op::Slice* slice = static_cast<const op::Slice*>(&node); const op::Slice* slice = static_cast<const op::Slice*>(&node);
reference::slice<T>(static_cast<const T*>(args[0]), reference::slice<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
slice->get_lower_bounds(), slice->get_lower_bounds(),
slice->get_upper_bounds(), slice->get_upper_bounds(),
...@@ -1298,8 +1528,8 @@ private: ...@@ -1298,8 +1528,8 @@ private:
case OP_TYPEID::Softmax: case OP_TYPEID::Softmax:
{ {
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node); const op::Softmax* softmax = static_cast<const op::Softmax*>(&node);
reference::softmax<T>(static_cast<const T*>(args[0]), reference::softmax<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_output_shape(0), node.get_output_shape(0),
softmax->get_axes()); softmax->get_axes());
break; break;
...@@ -1308,7 +1538,7 @@ private: ...@@ -1308,7 +1538,7 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sqrt<T>( reference::sqrt<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::StopGradient: { throw unsupported_op("Unsupported op 'StopGradient'"); case OP_TYPEID::StopGradient: { throw unsupported_op("Unsupported op 'StopGradient'");
...@@ -1316,17 +1546,17 @@ private: ...@@ -1316,17 +1546,17 @@ private:
case OP_TYPEID::Subtract: case OP_TYPEID::Subtract:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::subtract<T>(static_cast<const T*>(args[0]), reference::subtract<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Sum: case OP_TYPEID::Sum:
{ {
const op::Sum* sum = static_cast<const op::Sum*>(&node); const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(static_cast<const T*>(args[0]), reference::sum<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
sum->get_reduction_axes()); sum->get_reduction_axes());
...@@ -1336,14 +1566,14 @@ private: ...@@ -1336,14 +1566,14 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::tan<T>( reference::tan<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Tanh: case OP_TYPEID::Tanh:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::tanh<T>( reference::tanh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::TopK: case OP_TYPEID::TopK:
...@@ -1351,9 +1581,9 @@ private: ...@@ -1351,9 +1581,9 @@ private:
const op::TopK* topk = static_cast<const op::TopK*>(&node); const op::TopK* topk = static_cast<const op::TopK*>(&node);
if (node.get_output_element_type(0) == element::i64) if (node.get_output_element_type(0) == element::i64)
{ {
reference::topk<T, int64_t>(static_cast<const T*>(args[0]), reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
static_cast<int64_t*>(out[0]), out[0]->get_data_ptr<int64_t>(),
static_cast<T*>(out[1]), out[1]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
topk->get_top_k_axis(), topk->get_top_k_axis(),
...@@ -1362,9 +1592,9 @@ private: ...@@ -1362,9 +1592,9 @@ private:
} }
else if (node.get_output_element_type(0) == element::i32) else if (node.get_output_element_type(0) == element::i32)
{ {
reference::topk<T, int32_t>(static_cast<const T*>(args[0]), reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
static_cast<int32_t*>(out[0]), out[0]->get_data_ptr<int32_t>(),
static_cast<T*>(out[1]), out[1]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
topk->get_top_k_axis(), topk->get_top_k_axis(),
...@@ -1377,7 +1607,12 @@ private: ...@@ -1377,7 +1607,12 @@ private:
} }
break; break;
} }
default: throw unsupported_op("Unsupported op '" + node.description() + "'"); case OP_TYPEID::DynBroadcast:
case OP_TYPEID::Transpose:
case OP_TYPEID::DynPad:
case OP_TYPEID::Tile:
case OP_TYPEID::DynReplaceSlice:
throw unsupported_op("Unsupported op '" + node.description() + "'");
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8)) #if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
#endif #endif
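Replacing the old default: with an explicit list of unsupported ops (DynBroadcast, Transpose, DynPad, Tile, DynReplaceSlice) is what makes the surrounding pragma region pay off: with no default label, -Wswitch can flag any OP_TYPEID value that lacks a case, so a newly added op fails loudly at compile time instead of silently falling through. A minimal illustration of the idiom:

enum class Op { Add, Tile };

int cost(Op op)
{
    switch (op) // no default: the compiler can warn when a value is unhandled
    {
    case Op::Add: return 1;
    case Op::Tile: return 7;
    }
    return 0; // unreachable once every enumerator is covered
}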
......
...@@ -140,6 +140,91 @@ namespace ngraph ...@@ -140,6 +140,91 @@ namespace ngraph
} }
} }
template <typename T>
void broadcast_5d(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
size_t index[5];
size_t* out_index = nullptr;
for (size_t i = 0; i < 5; i++)
{
if (broadcast_axes.count(i) == 0)
{
out_index = &index[i];
break;
}
}
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
{
out[index[0] * out_shape[1] * out_shape[2] * out_shape[3] *
out_shape[4] +
index[1] * out_shape[2] * out_shape[3] * out_shape[4] +
index[2] * out_shape[3] * out_shape[4] +
index[3] * out_shape[4] + index[4]] = in[*out_index];
}
}
}
}
}
}
template <typename T>
void broadcast_6d(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
size_t index[6];
size_t* out_index = nullptr;
for (size_t i = 0; i < 6; i++)
{
if (broadcast_axes.count(i) == 0)
{
out_index = &index[i];
break;
}
}
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
{
for (index[5] = 0; index[5] < out_shape[5]; ++index[5])
{
out[index[0] * out_shape[1] * out_shape[2] *
out_shape[3] * out_shape[4] * out_shape[5] +
index[1] * out_shape[2] * out_shape[3] *
out_shape[4] * out_shape[5] +
index[2] * out_shape[3] * out_shape[4] *
out_shape[5] +
index[3] * out_shape[4] * out_shape[5] +
index[4] * out_shape[5] + index[5]] =
in[*out_index];
}
}
}
}
}
}
}
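broadcast_5d and broadcast_6d assume a rank-1 input: out_index is pointed at the loop counter of the single non-broadcast axis, and every output element reads in[*out_index]. The hand-unrolled subscript arithmetic is plain row-major flattening, equivalent to this generic form:

#include <cstddef>
#include <vector>

// Horner-style row-major offset; expanding the recurrence for five or six
// dimensions reproduces the inline index expressions in the kernels above.
size_t row_major_offset(const std::vector<size_t>& index,
                        const std::vector<size_t>& shape)
{
    size_t offset = 0;
    for (size_t i = 0; i < shape.size(); ++i)
    {
        offset = offset * shape[i] + index[i];
    }
    return offset;
}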
template <typename T> template <typename T>
void broadcast(const T* in, void broadcast(const T* in,
T* out, T* out,
...@@ -167,6 +252,16 @@ namespace ngraph ...@@ -167,6 +252,16 @@ namespace ngraph
case 4: case 4:
broadcast_4d<T>(in, out, in_shape, out_shape, broadcast_axes); broadcast_4d<T>(in, out, in_shape, out_shape, broadcast_axes);
break; break;
case 5:
broadcast_5d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
case 6:
broadcast_6d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
default:
runtime::reference::broadcast<T>(
in, out, in_shape, out_shape, broadcast_axes);
break;
} }
} }
else else
......
...@@ -244,10 +244,7 @@ namespace ngraph ...@@ -244,10 +244,7 @@ namespace ngraph
case 4: reshape_in4<T>(in, out, in_shape, in_axis_order, out_shape); break; case 4: reshape_in4<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 5: reshape_in5<T>(in, out, in_shape, in_axis_order, out_shape); break; case 5: reshape_in5<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 6: reshape_in6<T>(in, out, in_shape, in_axis_order, out_shape); break; case 6: reshape_in6<T>(in, out, in_shape, in_axis_order, out_shape); break;
default: default: reference::reshape(in, out, in_shape, in_axis_order, out_shape); break;
NGRAPH_INFO << "reference::reshape";
reference::reshape(in, out, in_shape, in_axis_order, out_shape);
break;
} }
} }
} }
......
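Both hunks converge on the same dispatch shape: rank-unrolled fast paths where they exist (broadcast now up to rank 6, reshape already up to rank 6) and a quiet fallback to the generic reference kernel otherwise; the NGRAPH_INFO breadcrumb that used to mark the reshape fallback is dropped along the way.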
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <cstring>
#include <cmath>
#include <numeric>
#include <vector>
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace gcpu
{
namespace kernel
{
template <typename T>
void result(const T* arg, T* out, size_t count)
{
memcpy(out, arg, sizeof(T) * count);
}
}
}
}
}
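The new gcpu result kernel is a straight buffer copy, which is sound because every ngraph element type is trivially copyable. A usage sketch:

#include <array>

void result_demo()
{
    std::array<float, 4> src{1.f, 2.f, 3.f, 4.f};
    std::array<float, 4> dst{};
    ngraph::runtime::gcpu::kernel::result(src.data(), dst.data(), src.size());
}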
...@@ -51,7 +51,7 @@ class ngraph::runtime::gcpu::NodeWrapper ...@@ -51,7 +51,7 @@ class ngraph::runtime::gcpu::NodeWrapper
public: public:
NodeWrapper(const std::shared_ptr<const ngraph::Node>& node); NodeWrapper(const std::shared_ptr<const ngraph::Node>& node);
const Node& get_node() const { return *m_node; } std::shared_ptr<const Node> get_node() const { return m_node; }
ngraph::runtime::gcpu::OP_TYPEID get_typeid() const { return m_typeid; } ngraph::runtime::gcpu::OP_TYPEID get_typeid() const { return m_typeid; }
private: private:
std::shared_ptr<const ngraph::Node> m_node; std::shared_ptr<const ngraph::Node> m_node;
......
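get_node() now hands back the shared_ptr itself rather than a reference, so a caller can keep the node alive independently of the wrapper. A hypothetical call site showing the difference:

#include <memory>
#include <string>

// Before: const Node& n = wrapper.get_node();
// After: ownership can be retained past the wrapper's lifetime.
std::string op_name(const ngraph::runtime::gcpu::NodeWrapper& wrapper)
{
    std::shared_ptr<const ngraph::Node> node = wrapper.get_node();
    return node->description();
}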