Unverified Commit c9a9c154 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Change interpreter to use dynamic memory allocation (#2301)

* builds

* pass unit test

* cleanup
parent 917efb94
...@@ -74,12 +74,8 @@ runtime::Handle runtime::interpreter::INTBackend::compile(shared_ptr<Function> f ...@@ -74,12 +74,8 @@ runtime::Handle runtime::interpreter::INTBackend::compile(shared_ptr<Function> f
pass_manager.register_pass<pass::LikeReplacement>(); pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>(); pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>(get_alignment());
pass_manager.run_passes(function); pass_manager.run_passes(function);
size_t memory_pool_size = function->get_temporary_pool_size();
instance.m_temporary_memory.reset(new AlignedBuffer(memory_pool_size, get_alignment()));
for (const shared_ptr<Node>& node : function->get_ordered_ops()) for (const shared_ptr<Node>& node : function->get_ordered_ops())
{ {
instance.m_wrapped_nodes.emplace_back(node); instance.m_wrapped_nodes.emplace_back(node);
...@@ -105,29 +101,27 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -105,29 +101,27 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
} }
// convert inputs to HostTensor // convert inputs to HostTensor
vector<void*> func_inputs; vector<shared_ptr<HostTensor>> func_inputs;
vector<shared_ptr<runtime::HostTensor>> htv_inputs;
for (auto tensor : inputs) for (auto tensor : inputs)
{ {
auto host_tensor = static_pointer_cast<runtime::HostTensor>(tensor); auto host_tensor = static_pointer_cast<runtime::HostTensor>(tensor);
func_inputs.push_back(static_cast<void*>(host_tensor->get_data_ptr())); func_inputs.push_back(host_tensor);
htv_inputs.push_back(host_tensor);
} }
if (instance.m_nan_check_enabled) if (instance.m_nan_check_enabled)
{ {
perform_nan_check(htv_inputs); perform_nan_check(func_inputs);
} }
// convert outputs to HostTensor // convert outputs to HostTensor
vector<void*> func_outputs; vector<shared_ptr<HostTensor>> func_outputs;
for (auto tensor : outputs) for (auto tensor : outputs)
{ {
auto host_tensor = static_pointer_cast<runtime::HostTensor>(tensor); auto host_tensor = static_pointer_cast<runtime::HostTensor>(tensor);
func_outputs.push_back(static_cast<void*>(host_tensor->get_data_ptr())); func_outputs.push_back(host_tensor);
} }
// map function params -> HostTensor // map function params -> HostTensor
unordered_map<descriptor::Tensor*, void*> tensor_map; unordered_map<descriptor::Tensor*, shared_ptr<HostTensor>> tensor_map;
size_t input_count = 0; size_t input_count = 0;
for (auto param : function->get_parameters()) for (auto param : function->get_parameters())
{ {
...@@ -159,15 +153,9 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -159,15 +153,9 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
{ {
continue; continue;
} }
if (type_id == OP_TYPEID::Constant)
{
const op::Constant* c = static_cast<const op::Constant*>(op);
descriptor::Tensor* tensor = op->get_output_tensor_ptr(0).get();
tensor_map.insert({tensor, const_cast<void*>(c->get_data_ptr())});
continue;
}
// get op inputs from map // get op inputs from map
vector<const void*> op_inputs; vector<shared_ptr<HostTensor>> op_inputs;
for (const descriptor::Input& input : op->get_inputs()) for (const descriptor::Input& input : op->get_inputs())
{ {
descriptor::Tensor* tensor = input.get_output().get_tensor_ptr().get(); descriptor::Tensor* tensor = input.get_output().get_tensor_ptr().get();
...@@ -175,17 +163,18 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -175,17 +163,18 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
} }
// get op outputs from map or create // get op outputs from map or create
vector<void*> op_outputs; vector<shared_ptr<HostTensor>> op_outputs;
vector<shared_ptr<runtime::HostTensor>> htv_outputs;
for (size_t i = 0; i < op->get_output_size(); ++i) for (size_t i = 0; i < op->get_output_size(); ++i)
{ {
descriptor::Tensor* tensor = op->get_output_tensor_ptr(i).get(); descriptor::Tensor* tensor = op->get_output_tensor_ptr(i).get();
void* host_tensor = nullptr; shared_ptr<HostTensor> host_tensor;
auto it = tensor_map.find(tensor); auto it = tensor_map.find(tensor);
if (it == tensor_map.end()) if (it == tensor_map.end())
{ {
auto offset = op->get_output_tensor(i).get_pool_offset(); const Shape& shape = op->get_output_shape(i);
host_tensor = instance.get_temporary_pointer(offset); const element::Type& type = op->get_output_element_type(i);
string name = op->get_output_tensor(i).get_name();
host_tensor = make_shared<runtime::HostTensor>(type, shape, name);
tensor_map.insert({tensor, host_tensor}); tensor_map.insert({tensor, host_tensor});
} }
else else
...@@ -193,8 +182,6 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -193,8 +182,6 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
host_tensor = it->second; host_tensor = it->second;
} }
op_outputs.push_back(host_tensor); op_outputs.push_back(host_tensor);
htv_outputs.push_back(make_shared<runtime::HostTensor>(
tensor->get_element_type(), tensor->get_shape(), host_tensor));
} }
// get op type // get op type
...@@ -235,7 +222,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -235,7 +222,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
} }
if (instance.m_nan_check_enabled) if (instance.m_nan_check_enabled)
{ {
perform_nan_check(htv_outputs, op); perform_nan_check(op_outputs, op);
} }
} }
...@@ -244,24 +231,34 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -244,24 +231,34 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
void runtime::interpreter::INTBackend::generate_calls(const element::Type& type, void runtime::interpreter::INTBackend::generate_calls(const element::Type& type,
const NodeWrapper& op, const NodeWrapper& op,
const vector<void*>& outputs, const vector<shared_ptr<HostTensor>>& outputs,
const vector<const void*>& inputs, const vector<shared_ptr<HostTensor>>& inputs,
FunctionInstance& instance) FunctionInstance& instance)
{ {
vector<void*> out;
vector<const void*> in;
for (auto t : outputs)
{
out.push_back(t->get_data_ptr());
}
for (auto t : inputs)
{
in.push_back(t->get_data_ptr());
}
stringstream ss; stringstream ss;
switch (type.get_type_enum()) switch (type.get_type_enum())
{ {
case element::Type_t::boolean: op_engine<char>(op, outputs, inputs, instance); break; case element::Type_t::boolean: op_engine<char>(op, out, in, instance); break;
case element::Type_t::f32: op_engine<float>(op, outputs, inputs, instance); break; case element::Type_t::f32: op_engine<float>(op, out, in, instance); break;
case element::Type_t::f64: op_engine<double>(op, outputs, inputs, instance); break; case element::Type_t::f64: op_engine<double>(op, out, in, instance); break;
case element::Type_t::i8: op_engine<int8_t>(op, outputs, inputs, instance); break; case element::Type_t::i8: op_engine<int8_t>(op, out, in, instance); break;
case element::Type_t::i16: op_engine<int16_t>(op, outputs, inputs, instance); break; case element::Type_t::i16: op_engine<int16_t>(op, out, in, instance); break;
case element::Type_t::i32: op_engine<int32_t>(op, outputs, inputs, instance); break; case element::Type_t::i32: op_engine<int32_t>(op, out, in, instance); break;
case element::Type_t::i64: op_engine<int64_t>(op, outputs, inputs, instance); break; case element::Type_t::i64: op_engine<int64_t>(op, out, in, instance); break;
case element::Type_t::u8: op_engine<uint8_t>(op, outputs, inputs, instance); break; case element::Type_t::u8: op_engine<uint8_t>(op, out, in, instance); break;
case element::Type_t::u16: op_engine<uint16_t>(op, outputs, inputs, instance); break; case element::Type_t::u16: op_engine<uint16_t>(op, out, in, instance); break;
case element::Type_t::u32: op_engine<uint32_t>(op, outputs, inputs, instance); break; case element::Type_t::u32: op_engine<uint32_t>(op, out, in, instance); break;
case element::Type_t::u64: op_engine<uint64_t>(op, outputs, inputs, instance); break; case element::Type_t::u64: op_engine<uint64_t>(op, out, in, instance); break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
case element::Type_t::bf16: case element::Type_t::bf16:
......
...@@ -186,9 +186,6 @@ private: ...@@ -186,9 +186,6 @@ private:
std::unordered_map<const Node*, stopwatch> m_timer_map; std::unordered_map<const Node*, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes; std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states; std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states;
std::shared_ptr<AlignedBuffer> m_temporary_memory;
void* get_temporary_pointer(size_t offset) { return m_temporary_memory->get_ptr(offset); }
}; };
std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map; std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
std::set<std::string> m_unsupported_op_name_list; std::set<std::string> m_unsupported_op_name_list;
...@@ -198,8 +195,8 @@ private: ...@@ -198,8 +195,8 @@ private:
void generate_calls(const element::Type& type, void generate_calls(const element::Type& type,
const NodeWrapper& op, const NodeWrapper& op,
const std::vector<void*>& outputs, const std::vector<std::shared_ptr<HostTensor>>& outputs,
const std::vector<const void*>& inputs, const std::vector<std::shared_ptr<HostTensor>>& inputs,
FunctionInstance& instance); FunctionInstance& instance);
template <typename T> template <typename T>
...@@ -487,7 +484,9 @@ private: ...@@ -487,7 +484,9 @@ private:
} }
case OP_TYPEID::Constant: case OP_TYPEID::Constant:
{ {
// Constant is handled in the main loop const op::Constant* c = static_cast<const op::Constant*>(&node);
size_t element_count = shape_size(node.get_output_shape(0));
reference::constant<T>(c->get_data_ptr<T>(), static_cast<T*>(out[0]), element_count);
break; break;
} }
case OP_TYPEID::ScalarConstantLike: break; case OP_TYPEID::ScalarConstantLike: break;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment