Unverified Commit 6f511762 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Interpreter rework (#2030)

* all tests passing

* Renamed a few variables to be consistent with the new tensor names
parent b52a7798
......@@ -24,6 +24,7 @@
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/util.hpp"
using namespace std;
......@@ -31,6 +32,8 @@ using namespace ngraph;
using descriptor::layout::DenseTensorLayout;
// Alignment, in bytes, applied both to the MemoryLayout pass and to the
// AlignedBuffer backing the temporary memory pool (see compile()/call()).
const int runtime::interpreter::INTBackend::m_alignment = 64;
extern "C" const char* get_ngraph_version_string()
{
return NGRAPH_VERSION;
......@@ -63,8 +66,12 @@ bool runtime::interpreter::INTBackend::compile(shared_ptr<Function> function)
pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>(m_alignment);
pass_manager.run_passes(function);
size_t memory_pool_size = function->get_temporary_pool_size();
instance.m_temporary_memory.reset(new AlignedBuffer(memory_pool_size, m_alignment));
for (const shared_ptr<Node>& node : function->get_ordered_ops())
{
instance.m_wrapped_nodes.emplace_back(node);
......@@ -84,32 +91,36 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
FunctionInstance& instance = m_function_map[function];
// convert inputs to HostTensor
vector<shared_ptr<runtime::HostTensor>> func_inputs;
for (auto tv : inputs)
vector<void*> func_inputs;
vector<shared_ptr<runtime::HostTensor>> htv_inputs;
for (auto tensor : inputs)
{
func_inputs.push_back(static_pointer_cast<runtime::HostTensor>(tv));
auto host_tensor = static_pointer_cast<runtime::HostTensor>(tensor);
func_inputs.push_back(static_cast<void*>(host_tensor->get_data_ptr()));
htv_inputs.push_back(host_tensor);
}
if (instance.m_nan_check_enabled)
{
perform_nan_check(func_inputs);
perform_nan_check(htv_inputs);
}
// convert outputs to HostTensor
vector<shared_ptr<runtime::HostTensor>> func_outputs;
for (auto tv : outputs)
vector<void*> func_outputs;
for (auto tensor : outputs)
{
func_outputs.push_back(static_pointer_cast<runtime::HostTensor>(tv));
auto host_tensor = static_pointer_cast<runtime::HostTensor>(tensor);
func_outputs.push_back(static_cast<void*>(host_tensor->get_data_ptr()));
}
// map function params -> HostTensor
unordered_map<descriptor::Tensor*, shared_ptr<runtime::HostTensor>> tensor_map;
unordered_map<descriptor::Tensor*, void*> tensor_map;
size_t input_count = 0;
for (auto param : function->get_parameters())
{
for (size_t i = 0; i < param->get_output_size(); ++i)
{
descriptor::Tensor* tv = param->get_output_tensor_ptr(i).get();
tensor_map.insert({tv, func_inputs[input_count++]});
descriptor::Tensor* tensor = param->get_output_tensor_ptr(i).get();
tensor_map.insert({tensor, func_inputs[input_count++]});
}
}
......@@ -121,8 +132,8 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
{
throw ngraph_error("One of function's outputs isn't op::Result");
}
descriptor::Tensor* tv = output->get_output_tensor_ptr(0).get();
tensor_map.insert({tv, func_outputs[output_count]});
descriptor::Tensor* tensor = output->get_output_tensor_ptr(0).get();
tensor_map.insert({tensor, func_outputs[output_count]});
}
// for each ordered op in the graph
......@@ -134,35 +145,42 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
{
continue;
}
if (type_id == OP_TYPEID::Constant)
{
const op::Constant* c = static_cast<const op::Constant*>(op);
descriptor::Tensor* tensor = op->get_output_tensor_ptr(0).get();
tensor_map.insert({tensor, const_cast<void*>(c->get_data_ptr())});
continue;
}
// get op inputs from map
vector<shared_ptr<runtime::HostTensor>> op_inputs;
vector<const void*> op_inputs;
for (const descriptor::Input& input : op->get_inputs())
{
descriptor::Tensor* tv = input.get_output().get_tensor_ptr().get();
op_inputs.push_back(tensor_map.at(tv));
descriptor::Tensor* tensor = input.get_output().get_tensor_ptr().get();
op_inputs.push_back(tensor_map.at(tensor));
}
// get op outputs from map or create
vector<shared_ptr<runtime::HostTensor>> op_outputs;
vector<void*> op_outputs;
vector<shared_ptr<runtime::HostTensor>> htv_outputs;
for (size_t i = 0; i < op->get_output_size(); ++i)
{
descriptor::Tensor* tv = op->get_output_tensor_ptr(i).get();
shared_ptr<runtime::HostTensor> htv;
auto it = tensor_map.find(tv);
descriptor::Tensor* tensor = op->get_output_tensor_ptr(i).get();
void* host_tensor = nullptr;
auto it = tensor_map.find(tensor);
if (it == tensor_map.end())
{
// the output tensor is not in the tensor map so create a new tensor
const Shape& shape = op->get_output_shape(i);
const element::Type& type = op->get_output_element_type(i);
string name = op->get_output_tensor(i).get_name();
htv = make_shared<runtime::HostTensor>(type, shape, name);
tensor_map.insert({tv, htv});
auto offset = op->get_output_tensor(i).get_pool_offset();
host_tensor = instance.get_temporary_pointer(offset);
tensor_map.insert({tensor, host_tensor});
}
else
{
htv = it->second;
host_tensor = it->second;
}
op_outputs.push_back(htv);
op_outputs.push_back(host_tensor);
htv_outputs.push_back(make_shared<runtime::HostTensor>(
tensor->get_element_type(), tensor->get_shape(), host_tensor));
}
// get op type
......@@ -202,20 +220,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
}
if (instance.m_nan_check_enabled)
{
perform_nan_check(op_outputs, op);
}
// delete any obsolete tensors
for (const descriptor::Tensor* t : op->liveness_free_list)
{
for (auto it = tensor_map.begin(); it != tensor_map.end(); ++it)
{
if (it->second->get_name() == t->get_name())
{
tensor_map.erase(it);
break;
}
}
perform_nan_check(htv_outputs, op);
}
}
......@@ -224,8 +229,8 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
void runtime::interpreter::INTBackend::generate_calls(const element::Type& type,
const NodeWrapper& op,
const vector<shared_ptr<HostTensor>>& outputs,
const vector<shared_ptr<HostTensor>>& inputs,
const vector<void*>& outputs,
const vector<const void*>& inputs,
FunctionInstance& instance)
{
if (type == element::boolean)
......@@ -307,17 +312,17 @@ vector<runtime::PerformanceCounter>
return rc;
}
void runtime::interpreter::INTBackend::perform_nan_check(const vector<shared_ptr<HostTensor>>& tvs,
const Node* op)
void runtime::interpreter::INTBackend::perform_nan_check(
const vector<shared_ptr<HostTensor>>& tensors, const Node* op)
{
size_t arg_number = 1;
for (shared_ptr<HostTensor> tv : tvs)
for (const shared_ptr<HostTensor>& tensor : tensors)
{
const element::Type& type = tv->get_element_type();
const element::Type& type = tensor->get_element_type();
if (type == element::f32)
{
const float* data = tv->get_data_ptr<float>();
for (size_t i = 0; i < tv->get_element_count(); i++)
const float* data = tensor->get_data_ptr<float>();
for (size_t i = 0; i < tensor->get_element_count(); i++)
{
if (std::isnan(data[i]))
{
......@@ -335,8 +340,8 @@ void runtime::interpreter::INTBackend::perform_nan_check(const vector<shared_ptr
}
else if (type == element::f64)
{
const double* data = tv->get_data_ptr<double>();
for (size_t i = 0; i < tv->get_element_count(); i++)
const double* data = tensor->get_data_ptr<double>();
for (size_t i = 0; i < tensor->get_element_count(); i++)
{
if (std::isnan(data[i]))
{
......
......@@ -54,6 +54,7 @@
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
......@@ -165,6 +166,7 @@ public:
// The interpreter backend unconditionally reports every node type as supported.
bool is_supported(const Node& node) const override { return true; }
private:
static const int m_alignment;
class FunctionInstance
{
public:
......@@ -173,6 +175,9 @@ private:
bool m_performance_counters_enabled = false;
std::unordered_map<const Node*, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes;
std::unique_ptr<AlignedBuffer> m_temporary_memory;
// Returns a pointer `offset` bytes into the preallocated temporary memory
// pool (m_temporary_memory); offsets come from the MemoryLayout pass.
void* get_temporary_pointer(size_t offset) { return m_temporary_memory->get_ptr(offset); }
};
std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
......@@ -181,14 +186,14 @@ private:
void generate_calls(const element::Type& type,
const NodeWrapper& op,
const std::vector<std::shared_ptr<HostTensor>>& outputs,
const std::vector<std::shared_ptr<HostTensor>>& inputs,
const std::vector<void*>& outputs,
const std::vector<const void*>& inputs,
FunctionInstance& instance);
template <typename T>
void op_engine(const NodeWrapper& node_wrapper,
const std::vector<std::shared_ptr<HostTensor>>& out,
const std::vector<std::shared_ptr<HostTensor>>& args,
const std::vector<void*>& out,
const std::vector<const void*>& args,
FunctionInstance& instance)
{
const Node& node = node_wrapper.get_node();
......@@ -205,58 +210,63 @@ private:
{
case OP_TYPEID::Abs:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::abs<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Acos:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::acos<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Add:
{
reference::add<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::add<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::AllReduce: {
#ifdef NGRAPH_DISTRIBUTED
reference::allreduce<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_element_type(),
static_cast<int>(args[0]->get_element_count()));
reference::allreduce<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_element_type(0),
static_cast<int>(shape_size(node.get_input_shape(0))));
#endif
break;
}
case OP_TYPEID::And:
{
reference::logical_and(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_and(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::ArgMin:
{
const op::ArgMin* argmin = static_cast<const op::ArgMin*>(&node);
if (out[0]->get_element_type() == element::i64)
auto element_type = node.get_output_element_type(0);
if (element_type == element::i64)
{
reference::argmin<T, int64_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int64_t>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::argmin<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
argmin->get_reduction_axis());
}
else if (out[0]->get_element_type() == element::i32)
else if (element_type == element::i32)
{
reference::argmin<T, int32_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int32_t>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::argmin<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
argmin->get_reduction_axis());
}
else
......@@ -268,20 +278,21 @@ private:
case OP_TYPEID::ArgMax:
{
const op::ArgMax* argmax = static_cast<const op::ArgMax*>(&node);
if (out[0]->get_element_type() == element::i64)
auto element_type = node.get_output_element_type(0);
if (element_type == element::i64)
{
reference::argmax<T, int64_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int64_t>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::argmax<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
argmax->get_reduction_axis());
}
else if (out[0]->get_element_type() == element::i32)
else if (element_type == element::i32)
{
reference::argmax<T, int32_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int32_t>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::argmax<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
argmax->get_reduction_axis());
}
else
......@@ -292,24 +303,26 @@ private:
}
case OP_TYPEID::Asin:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::asin<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Atan:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::atan<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::AvgPool:
{
const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
reference::avg_pool<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::avg_pool<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
avg_pool->get_window_shape(),
avg_pool->get_window_movement_strides(),
avg_pool->get_padding_below(),
......@@ -327,8 +340,9 @@ private:
const op::GetOutputElement* get_output_element =
static_cast<const op::GetOutputElement*>(&node);
size_t n = get_output_element->get_n();
size_t num_bytes = out[0]->get_element_count() * out[0]->get_element_type().size();
std::memcpy(out[0]->get_data_ptr(), args[n]->get_data_ptr(), num_bytes);
size_t element_count = shape_size(node.get_output_shape(0));
size_t num_bytes = element_count * node.get_output_element_type(0).size();
std::memcpy(static_cast<T*>(out[0]), args[n], num_bytes);
break;
}
case OP_TYPEID::BatchNormTraining:
......@@ -337,26 +351,25 @@ private:
static_cast<const ngraph::op::BatchNormTraining*>(&node);
if (bn->get_output_size() == 3)
{
reference::batch_norm_three_outputs<T>(
bn->get_eps_value(),
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[2]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
reinterpret_cast<T*>(out[1]->get_data_ptr()),
reinterpret_cast<T*>(out[2]->get_data_ptr()),
args[2]->get_shape());
reference::batch_norm_three_outputs<T>(bn->get_eps_value(),
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<T*>(out[0]),
static_cast<T*>(out[1]),
static_cast<T*>(out[2]),
node.get_input_shape(2));
}
else
{
reference::batch_norm_one_output<T>(bn->get_eps_value(),
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[2]->get_data_ptr()),
reinterpret_cast<T*>(args[3]->get_data_ptr()),
reinterpret_cast<T*>(args[4]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[2]->get_shape());
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<const T*>(args[3]),
static_cast<const T*>(args[4]),
static_cast<T*>(out[0]),
node.get_input_shape(2));
}
break;
}
......@@ -365,13 +378,13 @@ private:
const ngraph::op::BatchNormInference* bn =
static_cast<const ngraph::op::BatchNormInference*>(&node);
reference::batch_norm_one_output<T>(bn->get_eps_value(),
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[2]->get_data_ptr()),
reinterpret_cast<T*>(args[3]->get_data_ptr()),
reinterpret_cast<T*>(args[4]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[2]->get_shape());
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<const T*>(args[3]),
static_cast<const T*>(args[4]),
static_cast<T*>(out[0]),
node.get_input_shape(2));
break;
}
case OP_TYPEID::BatchNormTrainingBackprop:
......@@ -379,25 +392,25 @@ private:
const ngraph::op::BatchNormTrainingBackprop* bn_bprop =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(&node);
reference::batch_norm_backprop(bn_bprop->get_eps_value(),
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[2]->get_data_ptr()),
reinterpret_cast<T*>(args[3]->get_data_ptr()),
reinterpret_cast<T*>(args[4]->get_data_ptr()),
reinterpret_cast<T*>(args[5]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
reinterpret_cast<T*>(out[1]->get_data_ptr()),
reinterpret_cast<T*>(out[2]->get_data_ptr()),
args[2]->get_shape());
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<const T*>(args[3]),
static_cast<const T*>(args[4]),
static_cast<const T*>(args[5]),
static_cast<T*>(out[0]),
static_cast<T*>(out[1]),
static_cast<T*>(out[2]),
node.get_input_shape(2));
break;
}
case OP_TYPEID::AvgPoolBackprop:
{
const op::AvgPoolBackprop* apb = static_cast<const op::AvgPoolBackprop*>(&node);
reference::avg_pool_backprop<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::avg_pool_backprop<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
apb->get_window_shape(),
apb->get_window_movement_strides(),
apb->get_padding_below(),
......@@ -408,11 +421,11 @@ private:
case OP_TYPEID::Broadcast:
{
const op::Broadcast* broadcast = static_cast<const op::Broadcast*>(&node);
Shape in_shape = args[0]->get_shape();
Shape out_shape = out[0]->get_shape();
Shape in_shape = node.get_input_shape(0);
Shape out_shape = node.get_output_shape(0);
AxisSet broadcast_axes = broadcast->get_broadcast_axes();
reference::broadcast<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
reference::broadcast<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
in_shape,
out_shape,
broadcast_axes);
......@@ -420,8 +433,9 @@ private:
}
case OP_TYPEID::Ceiling:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::ceiling<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Concat:
......@@ -429,94 +443,82 @@ private:
const op::Concat* concat = static_cast<const op::Concat*>(&node);
std::vector<const T*> in_args;
std::vector<Shape> in_shapes;
for (std::shared_ptr<HostTensor> arg : args)
for (size_t i = 0; i < node.get_input_size(); i++)
{
in_args.push_back(arg->get_data_ptr<T>());
in_shapes.push_back(arg->get_shape());
in_args.push_back(static_cast<const T*>(args[i]));
in_shapes.push_back(node.get_input_shape(i));
}
reference::concat<T>(in_args,
out[0]->get_data_ptr<T>(),
static_cast<T*>(out[0]),
in_shapes,
out[0]->get_shape(),
node.get_output_shape(0),
concat->get_concatenation_axis());
break;
}
case OP_TYPEID::Constant:
{
const op::Constant* c = static_cast<const op::Constant*>(&node);
reference::constant<T>(
c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
// Constant is handled in the main loop
break;
}
case OP_TYPEID::Convert:
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
element::Type type = node.get_element_type();
size_t element_count = shape_size(node.get_output_shape(0));
if (type == element::boolean)
{
reference::convert<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<char*>(out[0]), element_count);
}
else if (type == element::f32)
{
reference::convert<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<float>(),
out[0]->get_element_count());
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<float*>(out[0]), element_count);
}
else if (type == element::f64)
{
reference::convert<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<double>(),
out[0]->get_element_count());
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<double*>(out[0]), element_count);
}
else if (type == element::i8)
{
reference::convert<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int8_t>(),
out[0]->get_element_count());
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int8_t*>(out[0]), element_count);
}
else if (type == element::i16)
{
reference::convert<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int16_t>(),
out[0]->get_element_count());
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int16_t*>(out[0]), element_count);
}
else if (type == element::i32)
{
reference::convert<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int32_t>(),
out[0]->get_element_count());
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int32_t*>(out[0]), element_count);
}
else if (type == element::i64)
{
reference::convert<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int64_t>(),
out[0]->get_element_count());
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int64_t*>(out[0]), element_count);
}
else if (type == element::u8)
{
reference::convert<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<uint8_t>(),
out[0]->get_element_count());
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint8_t*>(out[0]), element_count);
}
else if (type == element::u16)
{
reference::convert<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<uint16_t>(),
out[0]->get_element_count());
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint16_t*>(out[0]), element_count);
}
else if (type == element::u32)
{
reference::convert<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<uint32_t>(),
out[0]->get_element_count());
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint32_t*>(out[0]), element_count);
}
else if (type == element::u64)
{
reference::convert<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<uint64_t>(),
out[0]->get_element_count());
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint64_t*>(out[0]), element_count);
}
else
{
......@@ -529,12 +531,12 @@ private:
case OP_TYPEID::Convolution:
{
const op::Convolution* c = static_cast<const op::Convolution*>(&node);
reference::convolution<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
reference::convolution<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
c->get_window_movement_strides(),
c->get_window_dilation_strides(),
c->get_padding_below(),
......@@ -553,12 +555,12 @@ private:
{
const op::ConvolutionBackpropFilters* c =
static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
reference::convolution<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
......@@ -578,12 +580,12 @@ private:
// Note that args[1] and args[0] are switched here from the usual order.
const op::ConvolutionBackpropData* c =
static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(args[1]->get_data_ptr<T>(),
args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[1]->get_shape(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::convolution<T>(static_cast<const T*>(args[1]),
static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(1),
node.get_input_shape(0),
node.get_output_shape(0),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
......@@ -600,14 +602,16 @@ private:
}
case OP_TYPEID::Cos:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::cos<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Cosh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::cosh<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Dequantize:
......@@ -617,22 +621,22 @@ private:
if (type == element::f32)
{
reference::dequantize<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<float>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<float>(),
args[0]->get_shape(),
args[1]->get_shape(),
reference::dequantize<T>(static_cast<const T*>(args[0]),
static_cast<const float*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<float*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
dequantize->get_axes());
}
else if (type == element::f64)
{
reference::dequantize<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<double>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<double>(),
args[0]->get_shape(),
args[1]->get_shape(),
reference::dequantize<T>(static_cast<const T*>(args[0]),
static_cast<const double*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<double*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
dequantize->get_axes());
}
else
......@@ -646,43 +650,47 @@ private:
}
case OP_TYPEID::Divide:
{
reference::divide<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::divide<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Dot:
{
const op::Dot* dot = static_cast<const op::Dot*>(&node);
reference::dot(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
reference::dot(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
dot->get_reduction_axes_count());
break;
}
case OP_TYPEID::Equal:
{
reference::equal<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::equal<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Exp:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::exp<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Floor:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::floor<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::FunctionCall:
......@@ -690,15 +698,24 @@ private:
std::shared_ptr<Function> function = node.get_functions()[0];
std::vector<std::shared_ptr<runtime::Tensor>> outputs;
for (auto tv : out)
for (size_t i = 0; i < function->get_output_size(); i++)
{
outputs.push_back(std::static_pointer_cast<runtime::Tensor>(tv));
element::Type et = function->get_output_element_type(i);
Shape shape = function->get_output_shape(i);
auto host_tensor = std::make_shared<HostTensor>(et, shape, out[i]);
outputs.push_back(std::static_pointer_cast<runtime::Tensor>(host_tensor));
}
std::vector<std::shared_ptr<runtime::Tensor>> inputs;
for (auto tv : args)
auto parameters = function->get_parameters();
for (size_t i = 0; i < parameters.size(); i++)
{
inputs.push_back(std::static_pointer_cast<runtime::Tensor>(tv));
auto parameter = parameters[i];
element::Type et = parameter->get_element_type();
Shape shape = parameter->get_shape();
auto host_tensor =
std::make_shared<HostTensor>(et, shape, const_cast<void*>(args[i]));
inputs.push_back(std::static_pointer_cast<runtime::Tensor>(host_tensor));
}
call(function, outputs, inputs);
......@@ -706,48 +723,53 @@ private:
}
case OP_TYPEID::Greater:
{
reference::greater<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::greater<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::GreaterEq:
{
reference::greater_eq<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::greater_eq<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Less:
{
reference::less<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::less<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::LessEq:
{
reference::less_eq<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::less_eq<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Log:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::log<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::LRN:
{
const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
reference::lrn<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
lrn->get_alpha(),
lrn->get_beta(),
lrn->get_bias(),
......@@ -757,29 +779,30 @@ private:
case OP_TYPEID::Max:
{
const op::Max* max = static_cast<const op::Max*>(&node);
reference::max<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::max<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
max->get_reduction_axes());
break;
}
case OP_TYPEID::Maximum:
{
reference::maximum<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::maximum<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::MaxPool:
{
const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node);
reference::max_pool<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::max_pool<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
max_pool->get_window_shape(),
max_pool->get_window_movement_strides(),
max_pool->get_padding_below(),
......@@ -791,11 +814,11 @@ private:
const op::MaxPoolBackprop* max_pool_backprop =
static_cast<const op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[1]->get_shape(),
out[0]->get_shape(),
reference::max_pool_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(1),
node.get_output_shape(0),
max_pool_backprop->get_window_shape(),
max_pool_backprop->get_window_movement_strides(),
max_pool_backprop->get_padding_below(),
......@@ -805,65 +828,71 @@ private:
case OP_TYPEID::Min:
{
const op::Min* min = static_cast<const op::Min*>(&node);
reference::min<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::min<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
min->get_reduction_axes());
break;
}
case OP_TYPEID::Minimum:
{
reference::minimum<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::minimum<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Multiply:
{
reference::multiply<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::multiply<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Negative:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::negate<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Not:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_not(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::NotEqual:
{
reference::not_equal<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::not_equal<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
element_count);
break;
}
case OP_TYPEID::OneHot:
{
const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::one_hot<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
oh->get_one_hot_axis());
break;
}
case OP_TYPEID::Or:
{
reference::logical_or(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_or(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Parameter: break;
......@@ -871,9 +900,9 @@ private:
{
const op::Pad* pad = static_cast<const op::Pad*>(&node);
reference::pad(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
reference::pad(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_inputs().at(0).get_shape(),
node.get_output_shape(0),
pad->get_padding_below(),
......@@ -883,19 +912,20 @@ private:
}
case OP_TYPEID::Power:
{
reference::power<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::power<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Product:
{
const op::Product* product = static_cast<const op::Product*>(&node);
reference::product<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::product<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
product->get_reduction_axes());
break;
}
......@@ -906,23 +936,23 @@ private:
if (type == element::u8)
{
reference::quantize<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
args[2]->get_data_ptr<uint8_t>(),
out[0]->get_data_ptr<uint8_t>(),
args[0]->get_shape(),
args[1]->get_shape(),
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const uint8_t*>(args[2]),
static_cast<uint8_t*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
quantize->get_round_mode());
}
else if (type == element::i8)
{
reference::quantize<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
args[2]->get_data_ptr<int8_t>(),
out[0]->get_data_ptr<int8_t>(),
args[0]->get_shape(),
args[1]->get_shape(),
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const int8_t*>(args[2]),
static_cast<int8_t*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
quantize->get_round_mode());
}
......@@ -942,20 +972,18 @@ private:
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
auto tx = std::make_shared<HostTensor>(
node.get_inputs().at(0).get_element_type(), Shape{}, "reduce_temp_x");
node.get_inputs().at(0).get_element_type(), Shape{}, &x, "reduce_temp_x");
auto ty = std::make_shared<HostTensor>(
node.get_inputs().at(1).get_element_type(), Shape{}, "reduce_temp_y");
node.get_inputs().at(1).get_element_type(), Shape{}, &y, "reduce_temp_y");
auto tr = std::make_shared<HostTensor>(
node.get_output_element_type(0), Shape{}, "reduce_temp_r");
*(tx->get_data_ptr<T>()) = x;
*(ty->get_data_ptr<T>()) = y;
call(reduction_function, {tr}, {tx, ty});
return *(tr->get_data_ptr<T>());
};
reference::reduce(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
reference::reduce(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_inputs().at(0).get_shape(),
node.get_output_shape(0),
reduce->get_reduction_axes(),
......@@ -968,21 +996,23 @@ private:
std::shared_ptr<Function> reduction_function = reduce_window->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
auto tx = std::make_shared<HostTensor>(
node.get_inputs().at(0).get_element_type(), Shape{}, "reduce_window_temp_x");
auto ty = std::make_shared<HostTensor>(
node.get_inputs().at(1).get_element_type(), Shape{}, "reduce_window_temp_y");
auto tx = std::make_shared<HostTensor>(node.get_inputs().at(0).get_element_type(),
Shape{},
&x,
"reduce_window_temp_x");
auto ty = std::make_shared<HostTensor>(node.get_inputs().at(1).get_element_type(),
Shape{},
&y,
"reduce_window_temp_y");
auto tr = std::make_shared<HostTensor>(
node.get_output_element_type(0), Shape{}, "reduce_window_temp_r");
*(tx->get_data_ptr<T>()) = x;
*(ty->get_data_ptr<T>()) = y;
call(reduction_function, {tr}, {tx, ty});
return *(tr->get_data_ptr<T>());
};
reference::reduce_window(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
reference::reduce_window(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_inputs().at(0).get_shape(),
node.get_output_shape(0),
f,
......@@ -992,56 +1022,58 @@ private:
}
case OP_TYPEID::Relu:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::relu<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::ReluBackprop:
{
reference::relu_backprop<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::relu_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::ReplaceSlice:
{
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[1]->get_shape(),
reference::replace_slice<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
node.get_input_shape(1),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
slice->get_strides(),
out[0]->get_shape());
node.get_output_shape(0));
break;
}
case OP_TYPEID::Reshape:
{
const op::Reshape* reshape = static_cast<const op::Reshape*>(&node);
reference::reshape(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
reference::reshape(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
reshape->get_input_order(),
out[0]->get_shape());
node.get_output_shape(0));
break;
}
case OP_TYPEID::Result:
{
const op::Result* res = static_cast<const op::Result*>(&node);
reference::result(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
reference::result(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
shape_size(res->get_shape()));
break;
}
case OP_TYPEID::Reverse:
{
const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
reference::reverse(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::reverse(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
reverse->get_reversed_axes());
break;
}
......@@ -1049,14 +1081,14 @@ private:
{
const op::ReverseSequence* reverse = static_cast<const op::ReverseSequence*>(&node);
if (args[1]->get_element_type() == element::i32)
if (node.get_input_element_type(1) == element::i32)
{
reference::reverse_sequence<T, int>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
reference::reverse_sequence<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
reverse->get_batch_axis(),
reverse->get_sequence_axis(),
args[1]->get_data_ptr<int>());
static_cast<const int32_t*>(args[1]));
}
else
{
......@@ -1066,11 +1098,12 @@ private:
}
case OP_TYPEID::Select:
{
reference::select<T>(args[0]->get_data_ptr<char>(),
args[1]->get_data_ptr<T>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::select<T>(static_cast<const char*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::SelectAndScatter:
......@@ -1083,13 +1116,11 @@ private:
std::function<bool(T, T)> f_selection = [this, &node, selection_function](T x,
T y) -> bool {
auto tx = std::make_shared<runtime::HostTensor>(
node.get_inputs().at(0).get_element_type(), Shape{}, "selection_temp_x");
node.get_inputs().at(0).get_element_type(), Shape{}, &x, "selection_temp_x");
auto ty = std::make_shared<runtime::HostTensor>(
node.get_inputs().at(1).get_element_type(), Shape{}, "selection_temp_y");
node.get_inputs().at(1).get_element_type(), Shape{}, &y, "selection_temp_y");
auto tr = std::make_shared<runtime::HostTensor>(
element::boolean, Shape{}, "selection_temp_r");
*(tx->get_data_ptr<T>()) = x;
*(ty->get_data_ptr<T>()) = y;
call(selection_function, {tr}, {tx, ty});
return *(tr->get_data_ptr<char>());
};
......@@ -1098,24 +1129,22 @@ private:
select_and_scatter->get_functions()[1];
std::function<T(T, T)> f_scatter = [this, &node, scatter_function](T x, T y) -> T {
auto tx = std::make_shared<runtime::HostTensor>(
node.get_inputs().at(0).get_element_type(), Shape{}, "scatter_temp_x");
node.get_inputs().at(0).get_element_type(), Shape{}, &x, "scatter_temp_x");
auto ty = std::make_shared<runtime::HostTensor>(
node.get_inputs().at(1).get_element_type(), Shape{}, "scatter_temp_y");
node.get_inputs().at(1).get_element_type(), Shape{}, &y, "scatter_temp_y");
auto tr = std::make_shared<runtime::HostTensor>(
node.get_output_element_type(0), Shape{}, "scatter_temp_r");
*(tx->get_data_ptr<T>()) = x;
*(ty->get_data_ptr<T>()) = y;
call(scatter_function, {tr}, {tx, ty});
return *(tr->get_data_ptr<T>());
};
reference::select_and_scatter<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
reference::select_and_scatter<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
f_selection,
f_scatter,
select_and_scatter->get_window_shape(),
......@@ -1124,116 +1153,125 @@ private:
}
case OP_TYPEID::Sigmoid:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::SigmoidBackprop:
{
reference::sigmoid_backprop<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Sign:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sign<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Sin:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sin<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Sinh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sinh<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Slice:
{
const op::Slice* slice = static_cast<const op::Slice*>(&node);
reference::slice<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
reference::slice<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
slice->get_strides(),
out[0]->get_shape());
node.get_output_shape(0));
break;
}
case OP_TYPEID::Softmax:
{
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node);
reference::softmax<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_shape(),
reference::softmax<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_output_shape(0),
softmax->get_axes());
break;
}
case OP_TYPEID::Sqrt:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sqrt<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::StopGradient: { throw unsupported_op("Unsupported op 'StopGradient'");
}
case OP_TYPEID::Subtract:
{
reference::subtract<T>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
size_t element_count = shape_size(node.get_output_shape(0));
reference::subtract<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count);
break;
}
case OP_TYPEID::Sum:
{
const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::sum<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
node.get_input_shape(0),
node.get_output_shape(0),
sum->get_reduction_axes());
break;
}
case OP_TYPEID::Tan:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::tan<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Tanh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::tanh<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::TopK:
{
const op::TopK* topk = static_cast<const op::TopK*>(&node);
if (out[0]->get_element_type() == element::i64)
if (node.get_output_element_type(0) == element::i64)
{
reference::topk<T, int64_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int64_t>(),
out[1]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::topk<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
static_cast<T*>(out[1]),
node.get_input_shape(0),
node.get_output_shape(0),
topk->get_top_k_axis(),
topk->get_k(),
topk->get_compute_max());
}
else if (out[0]->get_element_type() == element::i32)
else if (node.get_output_element_type(0) == element::i32)
{
reference::topk<T, int32_t>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<int32_t>(),
out[1]->get_data_ptr<T>(),
args[0]->get_shape(),
out[0]->get_shape(),
reference::topk<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
static_cast<T*>(out[1]),
node.get_input_shape(0),
node.get_output_shape(0),
topk->get_top_k_axis(),
topk->get_k(),
topk->get_compute_max());
......
......@@ -34,7 +34,7 @@ namespace ngraph
}
}
template <typename T>
void relu_backprop(const T* arg, T* delta_arg, T* out, size_t count)
void relu_backprop(const T* arg, const T* delta_arg, T* out, size_t count)
{
T zero = 0;
for (size_t i = 0; i < count; i++)
......
......@@ -34,7 +34,7 @@ namespace ngraph
const Shape& arg_shape,
size_t batch_axis,
size_t sequence_axis,
U* sequence_lengths)
const U* sequence_lengths)
{
CoordinateTransform input_transform(arg_shape);
for (const Coordinate& in_coord : input_transform)
......
......@@ -37,7 +37,7 @@ namespace ngraph
}
template <typename T>
void sigmoid_backprop(const T* arg, T* delta_arg, T* out, size_t count)
void sigmoid_backprop(const T* arg, const T* delta_arg, T* out, size_t count)
{
T exp_value;
T func_x;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment