Unverified Commit 5cfc7e92 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

pass FunctionInstance to interpreter op execution (#1947)

parent 8ef1ec04
......@@ -200,7 +200,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
{
instance.m_timer_map[op].start();
}
generate_calls(type, wrapped, op_outputs, op_inputs);
generate_calls(type, wrapped, op_outputs, op_inputs, instance);
if (instance.m_performance_counters_enabled)
{
instance.m_timer_map[op].stop();
......@@ -230,51 +230,52 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
void runtime::interpreter::INTBackend::generate_calls(const element::Type& type,
const NodeWrapper& op,
const vector<shared_ptr<HostTensor>>& outputs,
const vector<shared_ptr<HostTensor>>& inputs)
const vector<shared_ptr<HostTensor>>& inputs,
FunctionInstance& instance)
{
if (type == element::boolean)
{
op_engine<char>(op, outputs, inputs);
op_engine<char>(op, outputs, inputs, instance);
}
else if (type == element::f32)
{
op_engine<float>(op, outputs, inputs);
op_engine<float>(op, outputs, inputs, instance);
}
else if (type == element::f64)
{
op_engine<double>(op, outputs, inputs);
op_engine<double>(op, outputs, inputs, instance);
}
else if (type == element::i8)
{
op_engine<int8_t>(op, outputs, inputs);
op_engine<int8_t>(op, outputs, inputs, instance);
}
else if (type == element::i16)
{
op_engine<int16_t>(op, outputs, inputs);
op_engine<int16_t>(op, outputs, inputs, instance);
}
else if (type == element::i32)
{
op_engine<int32_t>(op, outputs, inputs);
op_engine<int32_t>(op, outputs, inputs, instance);
}
else if (type == element::i64)
{
op_engine<int64_t>(op, outputs, inputs);
op_engine<int64_t>(op, outputs, inputs, instance);
}
else if (type == element::u8)
{
op_engine<uint8_t>(op, outputs, inputs);
op_engine<uint8_t>(op, outputs, inputs, instance);
}
else if (type == element::u16)
{
op_engine<uint16_t>(op, outputs, inputs);
op_engine<uint16_t>(op, outputs, inputs, instance);
}
else if (type == element::u32)
{
op_engine<uint32_t>(op, outputs, inputs);
op_engine<uint32_t>(op, outputs, inputs, instance);
}
else if (type == element::u64)
{
op_engine<uint64_t>(op, outputs, inputs);
op_engine<uint64_t>(op, outputs, inputs, instance);
}
else
{
......
......@@ -179,12 +179,14 @@ private:
void generate_calls(const element::Type& type,
const NodeWrapper& op,
const std::vector<std::shared_ptr<HostTensor>>& outputs,
const std::vector<std::shared_ptr<HostTensor>>& inputs);
const std::vector<std::shared_ptr<HostTensor>>& inputs,
FunctionInstance& instance);
template <typename T>
void op_engine(const NodeWrapper& node_wrapper,
const std::vector<std::shared_ptr<HostTensor>>& out,
const std::vector<std::shared_ptr<HostTensor>>& args)
const std::vector<std::shared_ptr<HostTensor>>& args,
FunctionInstance& instance)
{
const Node& node = node_wrapper.get_node();
std::string node_op = node.description();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment