Unverified Commit 5cfc7e92 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

pass FunctionInstance to interpreter op execution (#1947)

parent 8ef1ec04
...@@ -200,7 +200,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -200,7 +200,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
{ {
instance.m_timer_map[op].start(); instance.m_timer_map[op].start();
} }
generate_calls(type, wrapped, op_outputs, op_inputs); generate_calls(type, wrapped, op_outputs, op_inputs, instance);
if (instance.m_performance_counters_enabled) if (instance.m_performance_counters_enabled)
{ {
instance.m_timer_map[op].stop(); instance.m_timer_map[op].stop();
...@@ -230,51 +230,52 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function, ...@@ -230,51 +230,52 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
void runtime::interpreter::INTBackend::generate_calls(const element::Type& type, void runtime::interpreter::INTBackend::generate_calls(const element::Type& type,
const NodeWrapper& op, const NodeWrapper& op,
const vector<shared_ptr<HostTensor>>& outputs, const vector<shared_ptr<HostTensor>>& outputs,
const vector<shared_ptr<HostTensor>>& inputs) const vector<shared_ptr<HostTensor>>& inputs,
FunctionInstance& instance)
{ {
if (type == element::boolean) if (type == element::boolean)
{ {
op_engine<char>(op, outputs, inputs); op_engine<char>(op, outputs, inputs, instance);
} }
else if (type == element::f32) else if (type == element::f32)
{ {
op_engine<float>(op, outputs, inputs); op_engine<float>(op, outputs, inputs, instance);
} }
else if (type == element::f64) else if (type == element::f64)
{ {
op_engine<double>(op, outputs, inputs); op_engine<double>(op, outputs, inputs, instance);
} }
else if (type == element::i8) else if (type == element::i8)
{ {
op_engine<int8_t>(op, outputs, inputs); op_engine<int8_t>(op, outputs, inputs, instance);
} }
else if (type == element::i16) else if (type == element::i16)
{ {
op_engine<int16_t>(op, outputs, inputs); op_engine<int16_t>(op, outputs, inputs, instance);
} }
else if (type == element::i32) else if (type == element::i32)
{ {
op_engine<int32_t>(op, outputs, inputs); op_engine<int32_t>(op, outputs, inputs, instance);
} }
else if (type == element::i64) else if (type == element::i64)
{ {
op_engine<int64_t>(op, outputs, inputs); op_engine<int64_t>(op, outputs, inputs, instance);
} }
else if (type == element::u8) else if (type == element::u8)
{ {
op_engine<uint8_t>(op, outputs, inputs); op_engine<uint8_t>(op, outputs, inputs, instance);
} }
else if (type == element::u16) else if (type == element::u16)
{ {
op_engine<uint16_t>(op, outputs, inputs); op_engine<uint16_t>(op, outputs, inputs, instance);
} }
else if (type == element::u32) else if (type == element::u32)
{ {
op_engine<uint32_t>(op, outputs, inputs); op_engine<uint32_t>(op, outputs, inputs, instance);
} }
else if (type == element::u64) else if (type == element::u64)
{ {
op_engine<uint64_t>(op, outputs, inputs); op_engine<uint64_t>(op, outputs, inputs, instance);
} }
else else
{ {
......
...@@ -179,12 +179,14 @@ private: ...@@ -179,12 +179,14 @@ private:
void generate_calls(const element::Type& type, void generate_calls(const element::Type& type,
const NodeWrapper& op, const NodeWrapper& op,
const std::vector<std::shared_ptr<HostTensor>>& outputs, const std::vector<std::shared_ptr<HostTensor>>& outputs,
const std::vector<std::shared_ptr<HostTensor>>& inputs); const std::vector<std::shared_ptr<HostTensor>>& inputs,
FunctionInstance& instance);
template <typename T> template <typename T>
void op_engine(const NodeWrapper& node_wrapper, void op_engine(const NodeWrapper& node_wrapper,
const std::vector<std::shared_ptr<HostTensor>>& out, const std::vector<std::shared_ptr<HostTensor>>& out,
const std::vector<std::shared_ptr<HostTensor>>& args) const std::vector<std::shared_ptr<HostTensor>>& args,
FunctionInstance& instance)
{ {
const Node& node = node_wrapper.get_node(); const Node& node = node_wrapper.get_node();
std::string node_op = node.description(); std::string node_op = node.description();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment