Unverified commit 47342339, authored by Scott Cyphers and committed by GitHub

Merge pull request #3178 from NervanaSystems/bob/gcpu

Update generic CPU backend to latest ngraph API
parents 30527e80 c7630c05
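
The recurring change throughout this diff is the move from nGraph's descriptor-based tensor accessors to the newer Input/Output handle API. A minimal sketch of the two idioms side by side, assuming a Node* op with at least one output (variable names are illustrative, not taken from the diff):

    descriptor::Tensor* t_old = op->get_output_tensor_ptr(0).get(); // pre-change API
    descriptor::Tensor* t_new = &op->output(0).get_tensor();        // post-change API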
@@ -15,10 +15,10 @@
# ******************************************************************************
if (NGRAPH_GENERIC_CPU_ENABLE)
find_package(OpenMP)
if (OPENMP_FOUND)
add_compile_options(${OpenMP_CXX_FLAGS})
endif()
# find_package(OpenMP)
# if (OPENMP_FOUND)
# add_compile_options(${OpenMP_CXX_FLAGS})
# endif()
add_library(gcpu_backend SHARED gcpu_backend.cpp gcpu_executable.cpp node_wrapper.cpp)
if(NGRAPH_LIB_VERSIONING_ENABLE)
set_target_properties(gcpu_backend PROPERTIES
@@ -52,14 +52,14 @@ runtime::gcpu::GCPUBackend::GCPUBackend(const vector<string>& unsupported_op_nam
shared_ptr<runtime::Tensor> runtime::gcpu::GCPUBackend::create_tensor(const element::Type& type,
const Shape& shape)
{
return make_shared<runtime::HostTensor>(type, shape, this);
return make_shared<runtime::HostTensor>(type, shape);
}
shared_ptr<runtime::Tensor> runtime::gcpu::GCPUBackend::create_tensor(const element::Type& type,
const Shape& shape,
void* memory_pointer)
{
return make_shared<runtime::HostTensor>(type, shape, memory_pointer, this);
return make_shared<runtime::HostTensor>(type, shape, memory_pointer);
}
shared_ptr<runtime::Executable>
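
Both create_tensor overloads drop the trailing backend pointer (this) that HostTensor no longer accepts. A hedged usage sketch of the post-change constructors, with illustrative type and shape values:

    // backend-allocated buffer
    auto t1 = make_shared<runtime::HostTensor>(element::f32, Shape{2, 3});
    // caller-supplied buffer, wrapped but not owned
    float buf[6];
    auto t2 = make_shared<runtime::HostTensor>(element::f32, Shape{2, 3}, buf);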
@@ -15,17 +15,22 @@
//*****************************************************************************
#include "ngraph/runtime/generic_cpu/gcpu_executable.hpp"
#include "ngraph/cpio.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
using namespace std;
@@ -35,21 +40,35 @@ using descriptor::layout::DenseTensorLayout;
runtime::gcpu::GCPUExecutable::GCPUExecutable(const shared_ptr<Function>& function,
bool enable_performance_collection)
: m_is_compiled{true}
, m_performance_counters_enabled{enable_performance_collection}
{
m_function = clone_function(*function);
pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::FusedOpDecomposition>();
pass_manager.register_pass<pass::ImplicitBroadcastElimination>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.run_passes(m_function);
for (const shared_ptr<Node>& node : m_function->get_ordered_ops())
{
m_is_compiled = true;
pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.run_passes(function);
m_wrapped_nodes.emplace_back(node);
}
set_parameters_and_results(*m_function);
}
for (const shared_ptr<Node>& node : function->get_ordered_ops())
{
m_wrapped_nodes.emplace_back(node);
}
runtime::gcpu::GCPUExecutable::GCPUExecutable(const std::string& model_string)
: m_is_compiled{true}
, m_performance_counters_enabled{false}
{
m_function = deserialize(model_string);
for (const shared_ptr<Node>& node : m_function->get_ordered_ops())
{
m_wrapped_nodes.emplace_back(node);
}
set_parameters_and_results(*function);
set_parameters_and_results(*m_function);
}
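
Both constructors now populate m_function (by cloning or deserializing), run the pass pipeline against that copy rather than the caller's graph, and mark the result compiled in the initializer list. A hypothetical end-to-end use of the backend; the registration name "GCPU" and all variable names here are assumptions, not taken from this diff:

    auto backend = runtime::Backend::create("GCPU");
    auto exec = backend->compile(f);            // invokes the constructor above
    exec->call(output_tensors, input_tensors);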
bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
@@ -82,7 +101,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
{
for (size_t i = 0; i < param->get_output_size(); ++i)
{
descriptor::Tensor* tensor = param->get_output_tensor_ptr(i).get();
descriptor::Tensor* tensor = &param->output(i).get_tensor();
tensor_map.insert({tensor, func_inputs[input_count++]});
}
}
@@ -95,14 +114,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
{
throw ngraph_error("One of function's outputs isn't op::Result");
}
descriptor::Tensor* tensor = output->get_output_tensor_ptr(0).get();
descriptor::Tensor* tensor = &output->output(0).get_tensor();
tensor_map.insert({tensor, func_outputs[output_count]});
}
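
call() resolves every graph tensor to a concrete host buffer up front: parameters and results are seeded into a map keyed by descriptor::Tensor*, and intermediates are added as they are produced. A condensed sketch of the mapping idiom with the new accessor spelling (the `result` variable is illustrative):

    unordered_map<descriptor::Tensor*, shared_ptr<HostTensor>> tensor_map;
    descriptor::Tensor* key = &result->output(0).get_tensor();
    tensor_map.insert({key, func_outputs[0]});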
// for each ordered op in the graph
for (const NodeWrapper& wrapped : m_wrapped_nodes)
{
const Node* op = &wrapped.get_node();
auto op = wrapped.get_node();
auto type_id = wrapped.get_typeid();
if (type_id == OP_TYPEID::Parameter)
{
@@ -111,9 +130,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
// get op inputs from map
vector<shared_ptr<HostTensor>> op_inputs;
for (const descriptor::Input& input : op->get_inputs())
for (auto input : op->inputs())
{
descriptor::Tensor* tensor = input.get_output().get_tensor_ptr().get();
descriptor::Tensor* tensor = &input.get_tensor();
op_inputs.push_back(tensor_map.at(tensor));
}
@@ -121,14 +140,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
vector<shared_ptr<HostTensor>> op_outputs;
for (size_t i = 0; i < op->get_output_size(); ++i)
{
descriptor::Tensor* tensor = op->get_output_tensor_ptr(i).get();
descriptor::Tensor* tensor = &op->output(i).get_tensor();
shared_ptr<HostTensor> host_tensor;
auto it = tensor_map.find(tensor);
if (it == tensor_map.end())
{
const Shape& shape = op->get_output_shape(i);
const element::Type& type = op->get_output_element_type(i);
string name = op->get_output_tensor(i).get_name();
string name = op->output(i).get_tensor().get_name();
host_tensor = make_shared<runtime::HostTensor>(type, shape, name);
tensor_map.insert({tensor, host_tensor});
}
@@ -177,7 +196,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
}
if (m_nan_check_enabled)
{
perform_nan_check(op_outputs, op);
perform_nan_check(op_outputs, op.get());
}
}
@@ -186,19 +205,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
const NodeWrapper& op,
const vector<shared_ptr<HostTensor>>& outputs,
const vector<shared_ptr<HostTensor>>& inputs)
const vector<shared_ptr<HostTensor>>& out,
const vector<shared_ptr<HostTensor>>& in)
{
vector<void*> out;
vector<const void*> in;
for (auto t : outputs)
{
out.push_back(t->get_data_ptr());
}
for (auto t : inputs)
{
in.push_back(t->get_data_ptr());
}
stringstream ss;
switch (type.get_type_enum())
{
@@ -216,7 +225,8 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::bf16:
ss << "unsupported element type " << type << " op " << op.get_node().get_name();
case element::Type_t::f16:
ss << "unsupported element type " << type << " op " << op.get_node()->get_name();
throw ngraph_error(ss.str());
}
}
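
generate_calls() now forwards the HostTensor wrappers directly instead of flattening them into raw void*/const void* vectors; op_engine() pulls typed pointers on demand. The access idiom used throughout the rewritten cases, shown here for float:

    const float* src = in[0]->get_data_ptr<const float>();  // typed read pointer
    float* dst = out[0]->get_data_ptr<float>();             // typed write pointer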
@@ -229,11 +239,9 @@ void runtime::gcpu::GCPUExecutable::set_nan_check(bool enable)
vector<runtime::PerformanceCounter> runtime::gcpu::GCPUExecutable::get_performance_data() const
{
vector<runtime::PerformanceCounter> rc;
for (const pair<const Node*, stopwatch> p : m_timer_map)
for (const pair<shared_ptr<const Node>, stopwatch> p : m_timer_map)
{
rc.emplace_back(p.first->get_name().c_str(),
p.second.get_total_microseconds(),
p.second.get_call_count());
rc.emplace_back(p.first, p.second.get_total_microseconds(), p.second.get_call_count());
}
return rc;
}
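
m_timer_map now keys on shared_ptr<const Node>, so each PerformanceCounter can carry the node itself instead of a copied name string. A minimal read-out sketch (`exec` is an illustrative executable pointer):

    for (const runtime::PerformanceCounter& pc : exec->get_performance_data())
    {
        // one entry per timed node: total microseconds plus call count
    }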
@@ -286,3 +294,12 @@ void runtime::gcpu::GCPUExecutable::perform_nan_check(const vector<shared_ptr<Ho
arg_number++;
}
}
void runtime::gcpu::GCPUExecutable::save(ostream& out)
{
cpio::Writer writer(out);
string si = "INTERPRETER Save File 1.0";
writer.write("save_info", si.data(), si.size());
string model = serialize(m_function, 0);
writer.write("model", model.data(), model.size());
}
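
save() is new in this commit: it writes a CPIO archive containing a "save_info" record followed by the serialized "model". A hypothetical caller, using nothing beyond standard streams:

    std::ostringstream model_stream;
    exec->save(model_stream);
    std::string blob = model_stream.str(); // save_info + model records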
@@ -17,24 +17,31 @@
#pragma once
#include <initializer_list>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/gather.hpp"
@@ -48,11 +55,14 @@
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
@@ -64,7 +74,6 @@
#include "ngraph/runtime/generic_cpu/kernel/reshape.hpp"
#include "ngraph/runtime/generic_cpu/node_wrapper.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
@@ -77,7 +86,9 @@
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_mat_mul.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/broadcast_distributed.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
@@ -89,8 +100,10 @@
#include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/embedding_lookup.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/erf.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
@@ -117,14 +130,17 @@
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/recv.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/send.hpp"
#include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp"
@@ -134,6 +150,7 @@
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp"
@@ -154,6 +171,8 @@ namespace ngraph
class ngraph::runtime::gcpu::GCPUExecutable : public Executable
{
friend class GCPUBackend;
public:
GCPUExecutable(const std::shared_ptr<Function>& function,
bool enable_performance_collection = false);
@@ -161,20 +180,25 @@ public:
bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& inputs) override;
virtual void save(std::ostream& output_stream) override;
void set_nan_check(bool enable);
std::vector<PerformanceCounter> get_performance_data() const override;
private:
GCPUExecutable(const std::string& model_string);
int get_alignment() const { return 64; }
bool m_is_compiled = false;
bool m_nan_check_enabled = false;
bool m_performance_counters_enabled = false;
std::unordered_map<const Node*, stopwatch> m_timer_map;
std::shared_ptr<Function> m_function;
std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states;
std::set<std::string> m_unsupported_op_name_list;
int get_alignment() const { return 64; }
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
const Node* op = nullptr);
@@ -185,11 +209,10 @@ private:
template <typename T>
void op_engine(const NodeWrapper& node_wrapper,
const std::vector<void*>& out,
const std::vector<const void*>& args)
const std::vector<std::shared_ptr<HostTensor>>& out,
const std::vector<std::shared_ptr<HostTensor>>& args)
{
const Node& node = node_wrapper.get_node();
std::string node_op = node.description();
const Node& node = *node_wrapper.get_node();
// We want to check that every OP_TYPEID enumeration is included in the list.
// These GCC flags enable compile-time checking so that if an enumeration
@@ -206,30 +229,30 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::abs<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Acos:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::acos<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Add:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::add<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::add<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::All:
{
const op::All* all = static_cast<const op::All*>(&node);
reference::all(static_cast<const char*>(args[0]),
static_cast<char*>(out[0]),
reference::all(args[0]->get_data_ptr<const char>(),
out[0]->get_data_ptr<char>(),
node.get_input_shape(0),
node.get_output_shape(0),
all->get_reduction_axes());
@@ -237,26 +260,29 @@ private:
}
case OP_TYPEID::AllReduce:
{
reference::allreduce<T>(static_cast<T*>(const_cast<void*>(args[0])),
static_cast<T*>(out[0]),
node.get_input_element_type(0),
const ngraph::op::AllReduce* allreduce =
static_cast<const ngraph::op::AllReduce*>(&node);
reference::allreduce<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
allreduce->get_reduce_type(),
static_cast<int>(shape_size(node.get_input_shape(0))));
break;
}
case OP_TYPEID::And:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_and(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::logical_and(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::Any:
{
const op::Any* any = static_cast<const op::Any*>(&node);
reference::any(static_cast<const char*>(args[0]),
static_cast<char*>(out[0]),
reference::any(args[0]->get_data_ptr<const char>(),
out[0]->get_data_ptr<char>(),
node.get_input_shape(0),
node.get_output_shape(0),
any->get_reduction_axes());
@@ -268,16 +294,16 @@ private:
auto element_type = node.get_output_element_type(0);
if (element_type == element::i64)
{
reference::argmin<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
reference::argmin<T, int64_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int64_t>(),
node.get_input_shape(0),
node.get_output_shape(0),
argmin->get_reduction_axis());
}
else if (element_type == element::i32)
{
reference::argmin<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
reference::argmin<T, int32_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_output_shape(0),
argmin->get_reduction_axis());
@@ -294,16 +320,16 @@ private:
auto element_type = node.get_output_element_type(0);
if (element_type == element::i64)
{
reference::argmax<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
reference::argmax<T, int64_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int64_t>(),
node.get_input_shape(0),
node.get_output_shape(0),
argmax->get_reduction_axis());
}
else if (element_type == element::i32)
{
reference::argmax<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
reference::argmax<T, int32_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_output_shape(0),
argmax->get_reduction_axis());
@@ -318,22 +344,22 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::asin<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Atan:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::atan<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::AvgPool:
{
const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
reference::avg_pool<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
avg_pool->get_window_shape(),
@@ -345,18 +371,30 @@ private:
}
case OP_TYPEID::GenerateMask:
{
bool use_seed = static_cast<bool>(args[2]->get_data_ptr<const int32_t>()[0]);
if (m_states.count(&node) == 0)
{
const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
auto seed = use_seed ? gm->get_seed() : 0;
m_states[&node] = std::unique_ptr<ngraph::RNGState>(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability()));
ngraph::RNGState::create_rng_state(seed, gm->get_probability()));
}
bool training = static_cast<bool>(static_cast<const T*>(args[0])[0]);
bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]);
auto state = m_states.at(&node).get();
size_t element_count = shape_size(node.get_output_shape(0));
reference::generate_mask<T>(
reinterpret_cast<T*>(out[0]), element_count, state, training);
if (!use_seed)
{
reference::generate_mask<T>(
out[0]->get_data_ptr<T>(), element_count, state, training);
}
else
{
uint64_t seed = static_cast<uint64_t>(args[3]->get_data_ptr<const T>()[0]);
double prob = static_cast<double>(args[4]->get_data_ptr<const T>()[0]);
reference::generate_mask_no_state<T>(
out[0]->get_data_ptr<T>(), element_count, training, seed, prob);
}
break;
}
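// Input layout assumed by the GenerateMask handler above: args[0] is the
// training flag and args[2] the use_seed flag; args[3] (seed) and args[4]
// (probability) are only read when use_seed is set, otherwise the RNG state
// cached per node in m_states is reused.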
case OP_TYPEID::GetOutputElement:
@@ -366,20 +404,31 @@ private:
size_t n = get_output_element->get_n();
size_t element_count = shape_size(node.get_output_shape(0));
size_t num_bytes = element_count * node.get_output_element_type(0).size();
std::memcpy(static_cast<T*>(out[0]), args[n], num_bytes);
std::memcpy(out[0]->get_data_ptr<T>(), args[n]->get_data_ptr<T>(), num_bytes);
break;
}
case OP_TYPEID::BatchMatMul:
{
reference::batch_mat_mul(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0));
break;
}
case OP_TYPEID::BatchNormTraining:
{
const ngraph::op::BatchNormTraining* bn =
static_cast<const ngraph::op::BatchNormTraining*>(&node);
reference::batch_norm_training<T>(bn->get_eps_value(),
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<T*>(out[0]),
static_cast<T*>(out[1]),
static_cast<T*>(out[2]),
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
out[1]->get_data_ptr<T>(),
out[2]->get_data_ptr<T>(),
node.get_input_shape(2));
break;
}
@@ -388,12 +437,12 @@ private:
const ngraph::op::BatchNormInference* bn =
static_cast<const ngraph::op::BatchNormInference*>(&node);
reference::batch_norm_inference<T>(bn->get_eps_value(),
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<const T*>(args[3]),
static_cast<const T*>(args[4]),
static_cast<T*>(out[0]),
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const T>(),
args[4]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(2));
break;
}
@@ -402,23 +451,23 @@ private:
const ngraph::op::BatchNormTrainingBackprop* bn_bprop =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(&node);
reference::batch_norm_backprop(bn_bprop->get_eps_value(),
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<const T*>(args[3]),
static_cast<const T*>(args[4]),
static_cast<const T*>(args[5]),
static_cast<T*>(out[0]),
static_cast<T*>(out[1]),
static_cast<T*>(out[2]),
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const T>(),
args[4]->get_data_ptr<const T>(),
args[5]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
out[1]->get_data_ptr<T>(),
out[2]->get_data_ptr<T>(),
node.get_input_shape(2));
break;
}
case OP_TYPEID::AvgPoolBackprop:
{
const op::AvgPoolBackprop* apb = static_cast<const op::AvgPoolBackprop*>(&node);
reference::avg_pool_backprop<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::avg_pool_backprop<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
apb->get_window_shape(),
@@ -434,32 +483,37 @@ private:
Shape in_shape = node.get_input_shape(0);
Shape out_shape = node.get_output_shape(0);
AxisSet broadcast_axes = broadcast->get_broadcast_axes();
gcpu::kernel::broadcast<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
in_shape,
out_shape,
broadcast_axes);
kernel::broadcast<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
in_shape,
out_shape,
broadcast_axes);
break;
}
case OP_TYPEID::BroadcastDistributed:
{
int rank_ID = get_distributed_interface()->get_rank();
if (rank_ID == 0)
const ngraph::op::BroadcastDistributed* broadcast =
static_cast<const ngraph::op::BroadcastDistributed*>(&node);
int rank_ID;
rank_ID = get_distributed_interface()->get_rank();
int root_id = broadcast->get_root_id();
if (rank_ID == root_id)
{
reference::broadcastdistributed<T>(
static_cast<T*>(args[0]),
node.get_input_element_type(0),
static_cast<int>(shape_size(node.get_input_shape(0))));
auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) *
sizeof(node.get_input_element_type(0));
memcpy(out[0], args[0], memSize);
args[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
static_cast<int>(shape_size(node.get_input_shape(0))),
root_id);
auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) * sizeof(T);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
}
else
{
reference::broadcastdistributed<T>(
static_cast<T*>(out[0]),
node.get_input_element_type(0),
static_cast<int>(shape_size(node.get_input_shape(0))));
out[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
static_cast<int>(shape_size(node.get_input_shape(0))),
root_id);
}
break;
}
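// Net effect of the branch above: the root rank broadcasts args[0] and copies
// it through to its own out[0]; every other rank receives the broadcast
// directly into out[0].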
@@ -468,7 +522,7 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::ceiling<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Concat:
@@ -478,11 +532,11 @@ private:
std::vector<Shape> in_shapes;
for (size_t i = 0; i < node.get_input_size(); i++)
{
in_args.push_back(static_cast<const T*>(args[i]));
in_args.push_back(args[i]->get_data_ptr<const T>());
in_shapes.push_back(node.get_input_shape(i));
}
reference::concat<T>(in_args,
static_cast<T*>(out[0]),
out[0]->get_data_ptr<T>(),
in_shapes,
node.get_output_shape(0),
concat->get_concatenation_axis());
@@ -492,7 +546,7 @@ private:
{
const op::Constant* c = static_cast<const op::Constant*>(&node);
size_t element_count = shape_size(node.get_output_shape(0));
reference::constant<T>(c->get_data_ptr<T>(), static_cast<T*>(out[0]), element_count);
reference::constant<T>(c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::ScalarConstantLike: break;
@@ -505,52 +559,62 @@ private:
switch (type.get_type_enum())
{
case element::Type_t::boolean:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<char*>(out[0]), element_count);
reference::convert_to_bool<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
break;
case element::Type_t::f32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<float*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
break;
case element::Type_t::f64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<double*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<double>(),
element_count);
break;
case element::Type_t::i8:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int8_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int8_t>(),
element_count);
break;
case element::Type_t::i16:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int16_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int16_t>(),
element_count);
break;
case element::Type_t::i32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int32_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int32_t>(),
element_count);
break;
case element::Type_t::i64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int64_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int64_t>(),
element_count);
break;
case element::Type_t::u8:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint8_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint8_t>(),
element_count);
break;
case element::Type_t::u16:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint16_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint16_t>(),
element_count);
break;
case element::Type_t::u32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint32_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint32_t>(),
element_count);
break;
case element::Type_t::u64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint64_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint64_t>(),
element_count);
break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::bf16:
case element::Type_t::f16:
ss << "unsupported element type " << type << " op Convert";
throw std::runtime_error(ss.str());
}
@@ -559,9 +623,9 @@ private:
case OP_TYPEID::Convolution:
{
const op::Convolution* c = static_cast<const op::Convolution*>(&node);
reference::convolution<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::convolution<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
@@ -569,38 +633,26 @@ private:
c->get_window_dilation_strides(),
c->get_padding_below(),
c->get_padding_above(),
c->get_data_dilation_strides(),
0,
1,
1,
0,
0,
1,
false);
c->get_data_dilation_strides());
break;
}
case OP_TYPEID::ConvolutionBackpropFilters:
{
const op::ConvolutionBackpropFilters* c =
static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
1,
0,
0,
1,
1,
0,
false);
reference::convolution_backprop_filter<T>(
args[0]->get_data_ptr<const T>(), // input
args[1]->get_data_ptr<const T>(), // delta_convolution_output
out[0]->get_data_ptr<T>(), // delta_filter
c->get_input_shape(0), // input_shape
c->get_input_shape(1), // convolution_output_shape
c->get_filters_shape(), // filter_shape
c->get_window_dilation_strides_forward(),
c->get_window_movement_strides_forward(),
c->get_padding_below_forward(),
c->compute_backward_in_pad_above(),
c->get_data_dilation_strides_forward());
break;
}
case OP_TYPEID::ConvolutionBackpropData:
@@ -608,38 +660,31 @@ private:
// Note that args[1] and args[0] are switched here from the usual order.
const op::ConvolutionBackpropData* c =
static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(static_cast<const T*>(args[1]),
static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(1),
node.get_input_shape(0),
node.get_output_shape(0),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
0,
1,
0,
1,
0,
1,
true);
reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
c->get_input_shape(1),
c->get_input_shape(0),
c->get_data_batch_shape(),
c->get_data_dilation_strides_forward(),
c->get_window_dilation_strides_forward(),
c->compute_backward_delta_out_pad_below(),
c->compute_backward_delta_out_pad_above(),
c->get_window_movement_strides_forward());
break;
}
case OP_TYPEID::Cos:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::cos<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Cosh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::cosh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Dequantize:
@@ -649,20 +694,20 @@ private:
if (type == element::f32)
{
reference::dequantize<T>(static_cast<const T*>(args[0]),
static_cast<const float*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<float*>(out[0]),
reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const float>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<float>(),
node.get_input_shape(0),
node.get_input_shape(1),
dequantize->get_axes());
}
else if (type == element::f64)
{
reference::dequantize<T>(static_cast<const T*>(args[0]),
static_cast<const double*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<double*>(out[0]),
reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const double>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<double>(),
node.get_input_shape(0),
node.get_input_shape(1),
dequantize->get_axes());
@@ -680,9 +725,9 @@ private:
{
const op::Divide* divop = static_cast<const op::Divide*>(&node);
size_t element_count = shape_size(node.get_output_shape(0));
reference::divide<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::divide<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count,
divop->is_pythondiv());
break;
@@ -691,13 +736,23 @@ private:
{
const op::Dot* dot = static_cast<const op::Dot*>(&node);
gcpu::kernel::dot(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
dot->get_reduction_axes_count());
kernel::dot(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
dot->get_reduction_axes_count());
break;
}
case OP_TYPEID::DynReshape:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::DynSlice:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::EmbeddingLookup:
......@@ -708,33 +763,33 @@ private:
if (type == element::f32)
{
reference::embedding<T, float>(static_cast<const float*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::embedding<T, float>(args[0]->get_data_ptr<const float>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count,
embed->get_shape());
}
else if (type == element::f64)
{
reference::embedding<T, double>(static_cast<const double*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::embedding<T, double>(args[0]->get_data_ptr<const double>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count,
embed->get_shape());
}
else if (type == element::i32)
{
reference::embedding<T, int>(static_cast<const int*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count,
embed->get_shape());
reference::embedding<T, int32_t>(args[0]->get_data_ptr<const int>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count,
embed->get_shape());
}
else if (type == element::i64)
{
reference::embedding<T, int64_t>(static_cast<const int64_t*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::embedding<T, int64_t>(args[0]->get_data_ptr<const int64_t>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count,
embed->get_shape());
}
@@ -748,24 +803,56 @@ private:
case OP_TYPEID::Equal:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::equal<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::equal<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
case OP_TYPEID::Erf:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::erf<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Exp:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::exp<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
#ifdef INTERPRETER_USE_HYBRID
case OP_TYPEID::FunctionCall:
{
auto f = static_cast<const runtime::hybrid::op::FunctionCall*>(&node);
auto backend = f->get_backend();
auto executable = f->get_executable();
std::vector<std::shared_ptr<Tensor>> outputs;
std::vector<std::shared_ptr<Tensor>> inputs;
for (const std::shared_ptr<HostTensor>& t : out)
{
auto backend_tensor = backend->create_tensor(
t->get_element_type(), t->get_shape(), t->get_data_ptr());
outputs.push_back(backend_tensor);
}
for (const std::shared_ptr<HostTensor>& t : args)
{
auto backend_tensor = backend->create_tensor(
t->get_element_type(), t->get_shape(), t->get_data_ptr());
inputs.push_back(backend_tensor);
}
executable->call(outputs, inputs);
break;
}
#endif
case OP_TYPEID::Floor:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::floor<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Gather:
@@ -826,36 +913,36 @@ private:
case OP_TYPEID::Greater:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::greater<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::greater<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
case OP_TYPEID::GreaterEq:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::greater_eq<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::greater_eq<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
case OP_TYPEID::Less:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::less<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::less<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
case OP_TYPEID::LessEq:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::less_eq<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::less_eq<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
@@ -863,14 +950,14 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::log<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::LRN:
{
const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::lrn<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
lrn->get_alpha(),
lrn->get_beta(),
@@ -881,8 +968,8 @@ private:
case OP_TYPEID::Max:
{
const op::Max* max = static_cast<const op::Max*>(&node);
reference::max<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::max<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
max->get_reduction_axes());
@@ -891,9 +978,9 @@ private:
case OP_TYPEID::Maximum:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::maximum<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::maximum<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
@@ -901,8 +988,8 @@ private:
{
const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node);
reference::max_pool<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::max_pool<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
max_pool->get_window_shape(),
@@ -916,9 +1003,9 @@ private:
const op::MaxPoolBackprop* max_pool_backprop =
static_cast<const op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::max_pool_backprop<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(1),
node.get_output_shape(0),
max_pool_backprop->get_window_shape(),
@@ -930,8 +1017,8 @@ private:
case OP_TYPEID::Min:
{
const op::Min* min = static_cast<const op::Min*>(&node);
reference::min<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::min<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
min->get_reduction_axes());
@@ -940,18 +1027,18 @@ private:
case OP_TYPEID::Minimum:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::minimum<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::minimum<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::Multiply:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::multiply<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::multiply<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
@@ -959,30 +1046,30 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::negate<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Not:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_not(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::NotEqual:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::not_equal<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::not_equal<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
case OP_TYPEID::OneHot:
{
const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::one_hot<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
oh->get_one_hot_axis());
@@ -991,46 +1078,46 @@ private:
case OP_TYPEID::Or:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_or(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::logical_or(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::Parameter: break;
case OP_TYPEID::Passthrough:
{
const op::Passthrough* passthrough = static_cast<const op::Passthrough*>(&node);
throw unsupported_op{"Unsupported operation language: " + passthrough->language()};
}
case OP_TYPEID::Pad:
{
const op::Pad* pad = static_cast<const op::Pad*>(&node);
reference::pad(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_inputs().at(0).get_shape(),
node.get_output_shape(0),
reference::pad(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.input(0).get_shape(),
node.output(0).get_shape(),
pad->get_padding_below(),
pad->get_padding_above(),
pad->get_padding_interior());
pad->get_pad_mode());
break;
}
case OP_TYPEID::Passthrough:
{
const op::Passthrough* passthrough = static_cast<const op::Passthrough*>(&node);
throw unsupported_op{"Unsupported operation language: " + passthrough->language()};
}
case OP_TYPEID::Power:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::power<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::power<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::Product:
{
const op::Product* product = static_cast<const op::Product*>(&node);
reference::product<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::product<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
product->get_reduction_axes());
@@ -1043,10 +1130,10 @@ private:
if (type == element::u8)
{
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const uint8_t*>(args[2]),
static_cast<uint8_t*>(out[0]),
reference::quantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<uint8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
@@ -1054,10 +1141,10 @@ private:
}
else if (type == element::i8)
{
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const int8_t*>(args[2]),
static_cast<int8_t*>(out[0]),
reference::quantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
@@ -1065,10 +1152,10 @@ private:
}
else if (type == element::i32)
{
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const int32_t*>(args[2]),
static_cast<int32_t*>(out[0]),
reference::quantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const int32_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
@@ -1083,40 +1170,168 @@ private:
break;
}
case OP_TYPEID::QuantizedConvolution:
{
const op::QuantizedConvolution* qc =
static_cast<const op::QuantizedConvolution*>(&node);
auto input_element_type = qc->get_input_element_type(0);
auto filter_element_type = qc->get_input_element_type(1);
auto output_element_type = qc->get_output_element_type(0);
if (input_element_type == element::u8 && filter_element_type == element::i8 &&
output_element_type == element::i8)
{
reference::convolution<uint8_t, int8_t, int8_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const int8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int8_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
output_element_type == element::u8)
{
reference::convolution<uint8_t, uint8_t, uint8_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<uint8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const uint8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const uint8_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::i8 &&
output_element_type == element::i32)
{
reference::convolution<uint8_t, int8_t, int32_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const int8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int32_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
output_element_type == element::i32)
{
reference::convolution<uint8_t, uint8_t, int32_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const uint8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int32_t>());
}
else
{
std::stringstream ss;
ss << "unsupported element type";
throw std::runtime_error(ss.str());
}
break;
}
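// The four <input, filter, output, accumulation> type combinations accepted
// above, all accumulating in int32:
//   <uint8_t, int8_t,  int8_t,  int32_t>
//   <uint8_t, uint8_t, uint8_t, int32_t>
//   <uint8_t, int8_t,  int32_t, int32_t>
//   <uint8_t, uint8_t, int32_t, int32_t>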
case OP_TYPEID::QuantizedAvgPool:
case OP_TYPEID::QuantizedConvolutionBias:
case OP_TYPEID::QuantizedConvolutionBiasAdd:
case OP_TYPEID::QuantizedConvolutionBiasSignedAdd:
case OP_TYPEID::QuantizedConvolutionRelu:
case OP_TYPEID::QuantizedConvolution:
case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::QuantizedDotBias:
case OP_TYPEID::QuantizedDot:
{
throw unsupported_op("Unsupported op '" + node.description() + "'.");
throw unsupported_op("Unsupported op '" + node.description() +
"' in Interpreter back end.");
}
case OP_TYPEID::Recv:
{
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Recv*>(&node);
int src_id = op->get_src_id();
reference::recv<T>(args[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
src_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::Range:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::Relu:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::relu<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::ReluBackprop:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::relu_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::relu_backprop<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::ReplaceSlice:
{
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(1),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
@@ -1127,26 +1342,26 @@ private:
case OP_TYPEID::Reshape:
{
const op::Reshape* reshape = static_cast<const op::Reshape*>(&node);
gcpu::kernel::reshape(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
reshape->get_input_order(),
node.get_output_shape(0));
kernel::reshape(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
reshape->get_input_order(),
node.get_output_shape(0));
break;
}
case OP_TYPEID::Result:
{
const op::Result* res = static_cast<const op::Result*>(&node);
reference::result(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::result(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
shape_size(res->get_shape()));
break;
}
case OP_TYPEID::Reverse:
{
const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
reference::reverse(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::reverse(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
reverse->get_reversed_axes());
@@ -1158,12 +1373,12 @@ private:
if (node.get_input_element_type(1) == element::i32)
{
reference::reverse_sequence<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
reverse->get_batch_axis(),
reverse->get_sequence_axis(),
static_cast<const int32_t*>(args[1]));
args[1]->get_data_ptr<const int32_t>());
}
else
{
@@ -1234,31 +1449,46 @@ private:
case OP_TYPEID::Select:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::select<T>(static_cast<const char*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<T*>(out[0]),
reference::select<T>(args[0]->get_data_ptr<const char>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::Send:
{
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Send*>(&node);
int dest_id = op->get_dest_id();
reference::send<T>(args[0]->get_data_ptr<const T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
dest_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::ShapeOf:
{
reference::shape_of(node.get_input_shape(0), static_cast<uint64_t*>(out[0]));
reference::shape_of(node.get_input_shape(0), out[0]->get_data_ptr<uint64_t>());
break;
}
case OP_TYPEID::Sigmoid:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::SigmoidBackprop:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::sigmoid_backprop<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
@@ -1266,28 +1496,28 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sign<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Sin:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sin<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Sinh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sinh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Slice:
{
const op::Slice* slice = static_cast<const op::Slice*>(&node);
reference::slice<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::slice<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
@@ -1298,8 +1528,8 @@ private:
case OP_TYPEID::Softmax:
{
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node);
reference::softmax<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::softmax<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_output_shape(0),
softmax->get_axes());
break;
......@@ -1308,7 +1538,7 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sqrt<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::StopGradient: { throw unsupported_op("Unsupported op 'StopGradient'");
......@@ -1316,17 +1546,17 @@ private:
case OP_TYPEID::Subtract:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::subtract<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::subtract<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::Sum:
{
const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::sum<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
sum->get_reduction_axes());
......@@ -1336,14 +1566,14 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::tan<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Tanh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::tanh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::TopK:
......@@ -1351,9 +1581,9 @@ private:
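// TopK writes two outputs: out[0] holds the selected indices (typed i64 or
// i32 per the node's first output element type) and out[1] the values.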
const op::TopK* topk = static_cast<const op::TopK*>(&node);
if (node.get_output_element_type(0) == element::i64)
{
reference::topk<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
static_cast<T*>(out[1]),
reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int64_t>(),
out[1]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
topk->get_top_k_axis(),
......@@ -1362,9 +1592,9 @@ private:
}
else if (node.get_output_element_type(0) == element::i32)
{
reference::topk<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
static_cast<T*>(out[1]),
reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int32_t>(),
out[1]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
topk->get_top_k_axis(),
......@@ -1377,7 +1607,12 @@ private:
}
break;
}
default: throw unsupported_op("Unsupported op '" + node.description() + "'");
case OP_TYPEID::DynBroadcast:
case OP_TYPEID::Transpose:
case OP_TYPEID::DynPad:
case OP_TYPEID::Tile:
case OP_TYPEID::DynReplaceSlice:
throw unsupported_op("Unsupported op '" + node.description() + "'");
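// Listing the still-unsupported dynamic-shape ops as explicit cases, rather
// than a catch-all default, likely lets -Wswitch flag any OP_TYPEID values
// that are not handled at all when new ops are added to the enum.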
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
......
......@@ -140,6 +140,91 @@ namespace ngraph
}
}
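// Unrolled 5-D broadcast. These specializations appear to assume a rank-1
// input (exactly one non-broadcast output axis): out_index is pointed at the
// loop counter of that axis, so in[*out_index] walks the input while the five
// nested loops walk the output in row-major order. For out_shape
// {d0,...,d4}, the flat offset of index {i0,...,i4} is
// (((i0*d1 + i1)*d2 + i2)*d3 + i3)*d4 + i4, which the assignment below
// spells out in expanded form.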
template <typename T>
void broadcast_5d(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
size_t index[5];
size_t* out_index = nullptr;
for (size_t i = 0; i < 5; i++)
{
if (broadcast_axes.count(i) == 0)
{
out_index = &index[i];
break;
}
}
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
{
out[index[0] * out_shape[1] * out_shape[2] * out_shape[3] *
out_shape[4] +
index[1] * out_shape[2] * out_shape[3] * out_shape[4] +
index[2] * out_shape[3] * out_shape[4] +
index[3] * out_shape[4] + index[4]] = in[*out_index];
}
}
}
}
}
}
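// Same scheme extended to rank-6 outputs; the row-major offset is
// ((((i0*d1 + i1)*d2 + i2)*d3 + i3)*d4 + i4)*d5 + i5.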
template <typename T>
void broadcast_6d(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
size_t index[6];
size_t* out_index = nullptr;
for (size_t i = 0; i < 6; i++)
{
if (broadcast_axes.count(i) == 0)
{
out_index = &index[i];
break;
}
}
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
{
for (index[5] = 0; index[5] < out_shape[5]; ++index[5])
{
out[index[0] * out_shape[1] * out_shape[2] *
out_shape[3] * out_shape[4] * out_shape[5] +
index[1] * out_shape[2] * out_shape[3] *
out_shape[4] * out_shape[5] +
index[2] * out_shape[3] * out_shape[4] *
out_shape[5] +
index[3] * out_shape[4] * out_shape[5] +
index[4] * out_shape[5] + index[5]] =
in[*out_index];
}
}
}
}
}
}
}
template <typename T>
void broadcast(const T* in,
T* out,
......@@ -167,6 +252,16 @@ namespace ngraph
case 4:
broadcast_4d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
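// Ranks 5 and 6 now get the unrolled kernels above; anything larger
// falls back to the generic reference implementation in the default case.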
case 5:
broadcast_5d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
case 6:
broadcast_6d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
default:
runtime::reference::broadcast<T>(
in, out, in_shape, out_shape, broadcast_axes);
break;
}
}
else
......
......@@ -244,10 +244,7 @@ namespace ngraph
case 4: reshape_in4<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 5: reshape_in5<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 6: reshape_in6<T>(in, out, in_shape, in_axis_order, out_shape); break;
default:
NGRAPH_INFO << "reference::reshape";
reference::reshape(in, out, in_shape, in_axis_order, out_shape);
break;
default: reference::reshape(in, out, in_shape, in_axis_order, out_shape); break;
}
}
}
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <cmath>
#include <cstring>
#include <numeric>
#include <vector>
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace gcpu
{
namespace kernel
{
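// Result is a plain pass-through copy; memcpy is safe here because the
// element types instantiated for T are trivially copyable.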
template <typename T>
void result(const T* arg, T* out, size_t count)
{
memcpy(out, arg, sizeof(T) * count);
}
}
}
}
}
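For orientation, a minimal sketch of how the broadcast dispatch above might be exercised, assuming the kernel lives at ngraph::runtime::gcpu::kernel::broadcast with the signature shown earlier and its header is on the include path; the shapes and values here are hypothetical, not part of this commit:
#include <vector>
#include "ngraph/axis_set.hpp"
#include "ngraph/shape.hpp"
// Hypothetical driver: broadcast a length-3 vector across a 5-D output
// along axis 1 (every other output axis is a broadcast axis).
void broadcast_example()
{
    std::vector<float> in{1.0f, 2.0f, 3.0f};
    ngraph::Shape in_shape{3};
    ngraph::Shape out_shape{2, 3, 4, 5, 6};
    std::vector<float> out(ngraph::shape_size(out_shape));
    ngraph::AxisSet broadcast_axes{0, 2, 3, 4};
    ngraph::runtime::gcpu::kernel::broadcast<float>(
        in.data(), out.data(), in_shape, out_shape, broadcast_axes);
}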
......@@ -51,7 +51,7 @@ class ngraph::runtime::gcpu::NodeWrapper
public:
NodeWrapper(const std::shared_ptr<const ngraph::Node>& node);
const Node& get_node() const { return *m_node; }
std::shared_ptr<const Node> get_node() const { return m_node; }
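// Returning the shared_ptr rather than a bare reference lets callers
// share ownership and keep the node alive beyond the wrapper's lifetime.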
ngraph::runtime::gcpu::OP_TYPEID get_typeid() const { return m_typeid; }
private:
std::shared_ptr<const ngraph::Node> m_node;
......