Commit 9530267d authored by Robert Kimball's avatar Robert Kimball

minimal change update

parent ef2e0118
......@@ -75,6 +75,8 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
bool enable_performance_collection)
{
{
FunctionInstance& instance = m_function_instance;
instance.m_is_compiled = true;
pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
......@@ -83,20 +85,21 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
pass_manager.run_passes(function);
size_t memory_pool_size = function->get_temporary_pool_size();
m_temporary_memory.reset(new AlignedBuffer(memory_pool_size, get_alignment()));
instance.m_temporary_memory.reset(new AlignedBuffer(memory_pool_size, get_alignment()));
for (const shared_ptr<Node>& node : function->get_ordered_ops())
{
m_wrapped_nodes.emplace_back(node);
instance.m_wrapped_nodes.emplace_back(node);
}
}
set_parameters_and_results(*function);
}
bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
FunctionInstance& instance = m_function_instance;
// convert inputs to HostTensor
vector<void*> func_inputs;
vector<shared_ptr<runtime::HostTensor>> htv_inputs;
......@@ -106,7 +109,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
func_inputs.push_back(static_cast<void*>(host_tensor->get_data_ptr()));
htv_inputs.push_back(host_tensor);
}
if (m_nan_check_enabled)
if (instance.m_nan_check_enabled)
{
perform_nan_check(htv_inputs);
}
......@@ -144,7 +147,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
}
// for each ordered op in the graph
for (const NodeWrapper& wrapped : m_wrapped_nodes)
for (const NodeWrapper& wrapped : instance.m_wrapped_nodes)
{
const Node* op = &wrapped.get_node();
auto type_id = wrapped.get_typeid();
......@@ -178,7 +181,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
if (it == tensor_map.end())
{
auto offset = op->get_output_tensor(i).get_pool_offset();
host_tensor = get_temporary_pointer(offset);
host_tensor = instance.get_temporary_pointer(offset);
tensor_map.insert({tensor, host_tensor});
}
else
......@@ -217,16 +220,16 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
}
#pragma GCC diagnostic pop
if (m_performance_counters_enabled)
if (instance.m_performance_counters_enabled)
{
m_timer_map[op].start();
instance.m_timer_map[op].start();
}
generate_calls(type, wrapped, op_outputs, op_inputs);
if (m_performance_counters_enabled)
generate_calls(type, wrapped, op_outputs, op_inputs, instance);
if (instance.m_performance_counters_enabled)
{
m_timer_map[op].stop();
instance.m_timer_map[op].stop();
}
if (m_nan_check_enabled)
if (instance.m_nan_check_enabled)
{
perform_nan_check(htv_outputs, op);
}
......@@ -238,22 +241,23 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
void runtime::interpreter::INTExecutable::generate_calls(const element::Type& type,
const NodeWrapper& op,
const vector<void*>& outputs,
const vector<const void*>& inputs)
const vector<const void*>& inputs,
FunctionInstance& instance)
{
stringstream ss;
switch (type.get_type_enum())
{
case element::Type_t::boolean: op_engine<char>(op, outputs, inputs); break;
case element::Type_t::f32: op_engine<float>(op, outputs, inputs); break;
case element::Type_t::f64: op_engine<double>(op, outputs, inputs); break;
case element::Type_t::i8: op_engine<int8_t>(op, outputs, inputs); break;
case element::Type_t::i16: op_engine<int16_t>(op, outputs, inputs); break;
case element::Type_t::i32: op_engine<int32_t>(op, outputs, inputs); break;
case element::Type_t::i64: op_engine<int64_t>(op, outputs, inputs); break;
case element::Type_t::u8: op_engine<uint8_t>(op, outputs, inputs); break;
case element::Type_t::u16: op_engine<uint16_t>(op, outputs, inputs); break;
case element::Type_t::u32: op_engine<uint32_t>(op, outputs, inputs); break;
case element::Type_t::u64: op_engine<uint64_t>(op, outputs, inputs); break;
case element::Type_t::boolean: op_engine<char>(op, outputs, inputs, instance); break;
case element::Type_t::f32: op_engine<float>(op, outputs, inputs, instance); break;
case element::Type_t::f64: op_engine<double>(op, outputs, inputs, instance); break;
case element::Type_t::i8: op_engine<int8_t>(op, outputs, inputs, instance); break;
case element::Type_t::i16: op_engine<int16_t>(op, outputs, inputs, instance); break;
case element::Type_t::i32: op_engine<int32_t>(op, outputs, inputs, instance); break;
case element::Type_t::i64: op_engine<int64_t>(op, outputs, inputs, instance); break;
case element::Type_t::u8: op_engine<uint8_t>(op, outputs, inputs, instance); break;
case element::Type_t::u16: op_engine<uint16_t>(op, outputs, inputs, instance); break;
case element::Type_t::u32: op_engine<uint32_t>(op, outputs, inputs, instance); break;
case element::Type_t::u64: op_engine<uint64_t>(op, outputs, inputs, instance); break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::bf16:
......@@ -262,11 +266,18 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty
}
}
void runtime::interpreter::INTExecutable::set_nan_check(bool enable)
{
FunctionInstance& instance = m_function_instance;
instance.m_nan_check_enabled = enable;
}
vector<runtime::PerformanceCounter>
runtime::interpreter::INTExecutable::get_performance_data() const
{
vector<runtime::PerformanceCounter> rc;
for (const pair<const Node*, stopwatch> p : m_timer_map)
const FunctionInstance& instance = m_function_instance;
for (const pair<const Node*, stopwatch> p : instance.m_timer_map)
{
rc.emplace_back(p.first->get_name().c_str(),
p.second.get_total_microseconds(),
......
......@@ -180,11 +180,16 @@ public:
bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& intputs) override;
void set_nan_check(bool value) { m_nan_check_enabled = value; }
void set_nan_check(bool enable);
std::vector<PerformanceCounter> get_performance_data() const override;
private:
int get_alignment() const { return 64; }
class FunctionInstance
{
public:
bool m_is_compiled = false;
bool m_nan_check_enabled = false;
bool m_performance_counters_enabled = false;
std::unordered_map<const Node*, stopwatch> m_timer_map;
......@@ -193,18 +198,23 @@ private:
std::shared_ptr<AlignedBuffer> m_temporary_memory;
void* get_temporary_pointer(size_t offset) { return m_temporary_memory->get_ptr(offset); }
} m_function_instance;
std::set<std::string> m_unsupported_op_name_list;
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
const Node* op = nullptr);
void generate_calls(const element::Type& type,
const NodeWrapper& op,
const std::vector<void*>& outputs,
const std::vector<const void*>& inputs);
const std::vector<const void*>& inputs,
FunctionInstance& instance);
template <typename T>
void op_engine(const NodeWrapper& node_wrapper,
const std::vector<void*>& out,
const std::vector<const void*>& args)
const std::vector<const void*>& args,
FunctionInstance& instance)
{
const Node& node = node_wrapper.get_node();
std::string node_op = node.description();
......@@ -362,15 +372,15 @@ private:
}
case OP_TYPEID::GenerateMask:
{
if (m_states.count(&node) == 0)
if (instance.m_states.count(&node) == 0)
{
const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
m_states[&node] = std::unique_ptr<ngraph::RNGState>(
instance.m_states[&node] = std::unique_ptr<ngraph::RNGState>(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability()));
}
bool training = static_cast<bool>(static_cast<const T*>(args[0])[0]);
auto state = m_states.at(&node).get();
auto state = instance.m_states.at(&node).get();
size_t element_count = shape_size(node.get_output_shape(0));
reference::generate_mask<T>(
reinterpret_cast<T*>(out[0]), element_count, state, training);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment