Commit 9530267d authored by Robert Kimball's avatar Robert Kimball

minimal change update

parent ef2e0118
...@@ -75,6 +75,8 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f ...@@ -75,6 +75,8 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
bool enable_performance_collection) bool enable_performance_collection)
{ {
{ {
FunctionInstance& instance = m_function_instance;
instance.m_is_compiled = true;
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>(); pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>(); pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
...@@ -83,20 +85,21 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f ...@@ -83,20 +85,21 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
pass_manager.run_passes(function); pass_manager.run_passes(function);
size_t memory_pool_size = function->get_temporary_pool_size(); size_t memory_pool_size = function->get_temporary_pool_size();
m_temporary_memory.reset(new AlignedBuffer(memory_pool_size, get_alignment())); instance.m_temporary_memory.reset(new AlignedBuffer(memory_pool_size, get_alignment()));
for (const shared_ptr<Node>& node : function->get_ordered_ops()) for (const shared_ptr<Node>& node : function->get_ordered_ops())
{ {
m_wrapped_nodes.emplace_back(node); instance.m_wrapped_nodes.emplace_back(node);
} }
} }
set_parameters_and_results(*function); set_parameters_and_results(*function);
} }
bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs, bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs) const vector<shared_ptr<runtime::Tensor>>& inputs)
{ {
FunctionInstance& instance = m_function_instance;
// convert inputs to HostTensor // convert inputs to HostTensor
vector<void*> func_inputs; vector<void*> func_inputs;
vector<shared_ptr<runtime::HostTensor>> htv_inputs; vector<shared_ptr<runtime::HostTensor>> htv_inputs;
...@@ -106,7 +109,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime:: ...@@ -106,7 +109,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
func_inputs.push_back(static_cast<void*>(host_tensor->get_data_ptr())); func_inputs.push_back(static_cast<void*>(host_tensor->get_data_ptr()));
htv_inputs.push_back(host_tensor); htv_inputs.push_back(host_tensor);
} }
if (m_nan_check_enabled) if (instance.m_nan_check_enabled)
{ {
perform_nan_check(htv_inputs); perform_nan_check(htv_inputs);
} }
...@@ -144,7 +147,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime:: ...@@ -144,7 +147,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
} }
// for each ordered op in the graph // for each ordered op in the graph
for (const NodeWrapper& wrapped : m_wrapped_nodes) for (const NodeWrapper& wrapped : instance.m_wrapped_nodes)
{ {
const Node* op = &wrapped.get_node(); const Node* op = &wrapped.get_node();
auto type_id = wrapped.get_typeid(); auto type_id = wrapped.get_typeid();
...@@ -178,7 +181,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime:: ...@@ -178,7 +181,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
if (it == tensor_map.end()) if (it == tensor_map.end())
{ {
auto offset = op->get_output_tensor(i).get_pool_offset(); auto offset = op->get_output_tensor(i).get_pool_offset();
host_tensor = get_temporary_pointer(offset); host_tensor = instance.get_temporary_pointer(offset);
tensor_map.insert({tensor, host_tensor}); tensor_map.insert({tensor, host_tensor});
} }
else else
...@@ -217,16 +220,16 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime:: ...@@ -217,16 +220,16 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
} }
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
if (m_performance_counters_enabled) if (instance.m_performance_counters_enabled)
{ {
m_timer_map[op].start(); instance.m_timer_map[op].start();
} }
generate_calls(type, wrapped, op_outputs, op_inputs); generate_calls(type, wrapped, op_outputs, op_inputs, instance);
if (m_performance_counters_enabled) if (instance.m_performance_counters_enabled)
{ {
m_timer_map[op].stop(); instance.m_timer_map[op].stop();
} }
if (m_nan_check_enabled) if (instance.m_nan_check_enabled)
{ {
perform_nan_check(htv_outputs, op); perform_nan_check(htv_outputs, op);
} }
...@@ -238,22 +241,23 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime:: ...@@ -238,22 +241,23 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
void runtime::interpreter::INTExecutable::generate_calls(const element::Type& type, void runtime::interpreter::INTExecutable::generate_calls(const element::Type& type,
const NodeWrapper& op, const NodeWrapper& op,
const vector<void*>& outputs, const vector<void*>& outputs,
const vector<const void*>& inputs) const vector<const void*>& inputs,
FunctionInstance& instance)
{ {
stringstream ss; stringstream ss;
switch (type.get_type_enum()) switch (type.get_type_enum())
{ {
case element::Type_t::boolean: op_engine<char>(op, outputs, inputs); break; case element::Type_t::boolean: op_engine<char>(op, outputs, inputs, instance); break;
case element::Type_t::f32: op_engine<float>(op, outputs, inputs); break; case element::Type_t::f32: op_engine<float>(op, outputs, inputs, instance); break;
case element::Type_t::f64: op_engine<double>(op, outputs, inputs); break; case element::Type_t::f64: op_engine<double>(op, outputs, inputs, instance); break;
case element::Type_t::i8: op_engine<int8_t>(op, outputs, inputs); break; case element::Type_t::i8: op_engine<int8_t>(op, outputs, inputs, instance); break;
case element::Type_t::i16: op_engine<int16_t>(op, outputs, inputs); break; case element::Type_t::i16: op_engine<int16_t>(op, outputs, inputs, instance); break;
case element::Type_t::i32: op_engine<int32_t>(op, outputs, inputs); break; case element::Type_t::i32: op_engine<int32_t>(op, outputs, inputs, instance); break;
case element::Type_t::i64: op_engine<int64_t>(op, outputs, inputs); break; case element::Type_t::i64: op_engine<int64_t>(op, outputs, inputs, instance); break;
case element::Type_t::u8: op_engine<uint8_t>(op, outputs, inputs); break; case element::Type_t::u8: op_engine<uint8_t>(op, outputs, inputs, instance); break;
case element::Type_t::u16: op_engine<uint16_t>(op, outputs, inputs); break; case element::Type_t::u16: op_engine<uint16_t>(op, outputs, inputs, instance); break;
case element::Type_t::u32: op_engine<uint32_t>(op, outputs, inputs); break; case element::Type_t::u32: op_engine<uint32_t>(op, outputs, inputs, instance); break;
case element::Type_t::u64: op_engine<uint64_t>(op, outputs, inputs); break; case element::Type_t::u64: op_engine<uint64_t>(op, outputs, inputs, instance); break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
case element::Type_t::bf16: case element::Type_t::bf16:
...@@ -262,11 +266,18 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty ...@@ -262,11 +266,18 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty
} }
} }
void runtime::interpreter::INTExecutable::set_nan_check(bool enable)
{
FunctionInstance& instance = m_function_instance;
instance.m_nan_check_enabled = enable;
}
vector<runtime::PerformanceCounter> vector<runtime::PerformanceCounter>
runtime::interpreter::INTExecutable::get_performance_data() const runtime::interpreter::INTExecutable::get_performance_data() const
{ {
vector<runtime::PerformanceCounter> rc; vector<runtime::PerformanceCounter> rc;
for (const pair<const Node*, stopwatch> p : m_timer_map) const FunctionInstance& instance = m_function_instance;
for (const pair<const Node*, stopwatch> p : instance.m_timer_map)
{ {
rc.emplace_back(p.first->get_name().c_str(), rc.emplace_back(p.first->get_name().c_str(),
p.second.get_total_microseconds(), p.second.get_total_microseconds(),
......
...@@ -180,31 +180,41 @@ public: ...@@ -180,31 +180,41 @@ public:
bool call(const std::vector<std::shared_ptr<Tensor>>& outputs, bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& intputs) override; const std::vector<std::shared_ptr<Tensor>>& intputs) override;
void set_nan_check(bool value) { m_nan_check_enabled = value; } void set_nan_check(bool enable);
std::vector<PerformanceCounter> get_performance_data() const override; std::vector<PerformanceCounter> get_performance_data() const override;
private: private:
int get_alignment() const { return 64; } int get_alignment() const { return 64; }
bool m_nan_check_enabled = false; class FunctionInstance
bool m_performance_counters_enabled = false; {
std::unordered_map<const Node*, stopwatch> m_timer_map; public:
std::vector<NodeWrapper> m_wrapped_nodes; bool m_is_compiled = false;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states; bool m_nan_check_enabled = false;
std::shared_ptr<AlignedBuffer> m_temporary_memory; bool m_performance_counters_enabled = false;
std::unordered_map<const Node*, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states;
std::shared_ptr<AlignedBuffer> m_temporary_memory;
void* get_temporary_pointer(size_t offset) { return m_temporary_memory->get_ptr(offset); }
} m_function_instance;
std::set<std::string> m_unsupported_op_name_list;
void* get_temporary_pointer(size_t offset) { return m_temporary_memory->get_ptr(offset); }
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&, static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
const Node* op = nullptr); const Node* op = nullptr);
void generate_calls(const element::Type& type, void generate_calls(const element::Type& type,
const NodeWrapper& op, const NodeWrapper& op,
const std::vector<void*>& outputs, const std::vector<void*>& outputs,
const std::vector<const void*>& inputs); const std::vector<const void*>& inputs,
FunctionInstance& instance);
template <typename T> template <typename T>
void op_engine(const NodeWrapper& node_wrapper, void op_engine(const NodeWrapper& node_wrapper,
const std::vector<void*>& out, const std::vector<void*>& out,
const std::vector<const void*>& args) const std::vector<const void*>& args,
FunctionInstance& instance)
{ {
const Node& node = node_wrapper.get_node(); const Node& node = node_wrapper.get_node();
std::string node_op = node.description(); std::string node_op = node.description();
...@@ -362,15 +372,15 @@ private: ...@@ -362,15 +372,15 @@ private:
} }
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
{ {
if (m_states.count(&node) == 0) if (instance.m_states.count(&node) == 0)
{ {
const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node); const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
m_states[&node] = std::unique_ptr<ngraph::RNGState>( instance.m_states[&node] = std::unique_ptr<ngraph::RNGState>(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability())); ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability()));
} }
bool training = static_cast<bool>(static_cast<const T*>(args[0])[0]); bool training = static_cast<bool>(static_cast<const T*>(args[0])[0]);
auto state = m_states.at(&node).get(); auto state = instance.m_states.at(&node).get();
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::generate_mask<T>( reference::generate_mask<T>(
reinterpret_cast<T*>(out[0]), element_count, state, training); reinterpret_cast<T*>(out[0]), element_count, state, training);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment