Commit 26590326 authored by Avijit's avatar Avijit Committed by Scott Cyphers

Callback for writing events (#3769)

* Added a callback registration to the event class so that frameworks can decide how to write the events. Note: This is FULLY backwards compatible i.e., no change in the API so - won't break existing users

* Fixed compilation error

* Attempting to fix broken Windows build

* Fixed a race condition in the test

* Attempt to fix windows build again
parent 3b40c127
......@@ -29,16 +29,22 @@ static bool read_tracing_env_var()
return (std::getenv("NGRAPH_ENABLE_TRACING") != nullptr);
}
mutex ngraph::Event::s_file_mutex;
ofstream ngraph::Event::s_event_log;
bool ngraph::Event::s_tracing_enabled = read_tracing_env_var();
NGRAPH_API mutex ngraph::Event::s_file_mutex;
NGRAPH_API ofstream ngraph::Event::s_event_log;
NGRAPH_API bool ngraph::Event::s_tracing_enabled = read_tracing_env_var();
NGRAPH_API bool ngraph::Event::s_event_writer_registered = false;
NGRAPH_API std::function<void(const ngraph::Event& event)> ngraph::Event::s_event_writer;
void ngraph::Event::write_trace(const ngraph::Event& event)
{
if (is_tracing_enabled())
{
lock_guard<mutex> lock(s_file_mutex);
if (s_event_writer_registered)
{
s_event_writer(event);
return;
}
static bool so_initialized = false;
if (!so_initialized)
{
......
......@@ -30,6 +30,8 @@
#else
#include <unistd.h>
#endif
#include <functional>
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph
{
......@@ -90,15 +92,31 @@ namespace ngraph
m_stop = std::chrono::high_resolution_clock::now();
}
const std::string& get_name() const { return m_name; }
const std::string& get_category() const { return m_category; }
const std::string& get_agrs() const { return m_args; }
const std::chrono::time_point<std::chrono::high_resolution_clock>& get_start() const
{
return m_start;
}
const std::chrono::time_point<std::chrono::high_resolution_clock>& get_stop() const
{
return m_stop;
}
static void register_event_writer(std::function<void(const Event& event)> callback)
{
std::lock_guard<std::mutex> lock(s_file_mutex);
s_event_writer_registered = true;
s_event_writer = callback;
}
static void write_trace(const Event& event);
static bool is_tracing_enabled() { return s_tracing_enabled; }
static void enable_event_tracing();
static void disable_event_tracing();
std::string to_json() const;
Event(const Event&) = delete;
Event& operator=(Event const&) = delete;
private:
int m_pid;
std::chrono::time_point<std::chrono::high_resolution_clock> m_start;
......@@ -108,9 +126,11 @@ namespace ngraph
std::string m_category;
std::string m_args;
static std::mutex s_file_mutex;
static std::ofstream s_event_log;
static bool s_tracing_enabled;
NGRAPH_API static std::mutex s_file_mutex;
NGRAPH_API static std::ofstream s_event_log;
NGRAPH_API static bool s_tracing_enabled;
NGRAPH_API static std::function<void(const Event& event)> s_event_writer;
NGRAPH_API static bool s_event_writer_registered;
};
} // namespace ngraph
......@@ -29,7 +29,6 @@ using namespace std;
TEST(event_tracing, event_file)
{
// Set the environment variable to ensure logging
ngraph::Event::enable_event_tracing();
std::vector<std::thread> threads;
for (auto i = 0; i < 10; i++)
......@@ -39,16 +38,14 @@ TEST(event_tracing, event_file)
std::ostringstream oss;
oss << "Event: " << id;
ngraph::Event event(oss.str(), "Dummy", "none");
std::this_thread::sleep_for(std::chrono::milliseconds(200));
std::this_thread::sleep_for(std::chrono::milliseconds(20));
event.Stop();
ngraph::Event::write_trace(event);
});
std::this_thread::sleep_for(std::chrono::milliseconds(200));
std::this_thread::sleep_for(std::chrono::milliseconds(20));
threads.push_back(std::move(next_thread));
}
std::this_thread::sleep_for(std::chrono::milliseconds(200));
for (auto& next : threads)
{
next.join();
......@@ -62,3 +59,55 @@ TEST(event_tracing, event_file)
// TODO
ngraph::Event::disable_event_tracing();
}
TEST(event_tracing, event_writer_callback)
{
// Create the event writer
vector<ngraph::Event> event_list;
auto event_writer = [&](const ngraph::Event& event) { event_list.push_back(event); };
map<string, unique_ptr<ngraph::Event>> expected_event_table;
mutex expected_event_table_mtx;
ngraph::Event::enable_event_tracing();
ngraph::Event::register_event_writer(event_writer);
auto worker = [&](int worker_id) {
std::ostringstream oss;
oss << "Event: " << worker_id;
unique_ptr<ngraph::Event> event(new ngraph::Event(oss.str(), "Dummy", "none"));
std::this_thread::sleep_for(std::chrono::milliseconds(20));
event->Stop();
ngraph::Event::write_trace(*event);
lock_guard<mutex> lock(expected_event_table_mtx);
expected_event_table[event->get_name()] = move(event);
};
std::vector<std::thread> threads;
for (int i = 0; i < 10; i++)
{
std::thread thread_next(worker, i);
threads.push_back(move(thread_next));
}
for (auto& next : threads)
{
next.join();
}
ngraph::Event::disable_event_tracing();
// Now validate the events
ASSERT_EQ(10, event_list.size());
ASSERT_EQ(10, expected_event_table.size());
for (const auto& next_event : event_list)
{
const auto& expected_event_key = expected_event_table.find(next_event.get_name());
EXPECT_TRUE(expected_event_key != expected_event_table.end());
EXPECT_EQ(expected_event_key->second->get_name(), next_event.get_name());
EXPECT_EQ(expected_event_key->second->get_start(), next_event.get_start());
EXPECT_EQ(expected_event_key->second->get_stop(), next_event.get_stop());
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment