Commit 26590326 authored by Avijit's avatar Avijit Committed by Scott Cyphers

Callback for writing events (#3769)

* Added a callback registration to the event class so that frameworks can decide how to write the events. Note: This is FULLY backwards compatible i.e., no change in the API so - won't break existing users

* Fixed compilation error

* Attempting to fix broken Windows build

* Fixed a race condition in the test

* Attempt to fix windows build again
parent 3b40c127
...@@ -29,16 +29,22 @@ static bool read_tracing_env_var() ...@@ -29,16 +29,22 @@ static bool read_tracing_env_var()
return (std::getenv("NGRAPH_ENABLE_TRACING") != nullptr); return (std::getenv("NGRAPH_ENABLE_TRACING") != nullptr);
} }
mutex ngraph::Event::s_file_mutex; NGRAPH_API mutex ngraph::Event::s_file_mutex;
ofstream ngraph::Event::s_event_log; NGRAPH_API ofstream ngraph::Event::s_event_log;
bool ngraph::Event::s_tracing_enabled = read_tracing_env_var(); NGRAPH_API bool ngraph::Event::s_tracing_enabled = read_tracing_env_var();
NGRAPH_API bool ngraph::Event::s_event_writer_registered = false;
NGRAPH_API std::function<void(const ngraph::Event& event)> ngraph::Event::s_event_writer;
void ngraph::Event::write_trace(const ngraph::Event& event) void ngraph::Event::write_trace(const ngraph::Event& event)
{ {
if (is_tracing_enabled()) if (is_tracing_enabled())
{ {
lock_guard<mutex> lock(s_file_mutex); lock_guard<mutex> lock(s_file_mutex);
if (s_event_writer_registered)
{
s_event_writer(event);
return;
}
static bool so_initialized = false; static bool so_initialized = false;
if (!so_initialized) if (!so_initialized)
{ {
......
...@@ -30,6 +30,8 @@ ...@@ -30,6 +30,8 @@
#else #else
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <functional>
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -90,15 +92,31 @@ namespace ngraph ...@@ -90,15 +92,31 @@ namespace ngraph
m_stop = std::chrono::high_resolution_clock::now(); m_stop = std::chrono::high_resolution_clock::now();
} }
const std::string& get_name() const { return m_name; }
const std::string& get_category() const { return m_category; }
const std::string& get_agrs() const { return m_args; }
const std::chrono::time_point<std::chrono::high_resolution_clock>& get_start() const
{
return m_start;
}
const std::chrono::time_point<std::chrono::high_resolution_clock>& get_stop() const
{
return m_stop;
}
static void register_event_writer(std::function<void(const Event& event)> callback)
{
std::lock_guard<std::mutex> lock(s_file_mutex);
s_event_writer_registered = true;
s_event_writer = callback;
}
static void write_trace(const Event& event); static void write_trace(const Event& event);
static bool is_tracing_enabled() { return s_tracing_enabled; } static bool is_tracing_enabled() { return s_tracing_enabled; }
static void enable_event_tracing(); static void enable_event_tracing();
static void disable_event_tracing(); static void disable_event_tracing();
std::string to_json() const; std::string to_json() const;
Event(const Event&) = delete;
Event& operator=(Event const&) = delete;
private: private:
int m_pid; int m_pid;
std::chrono::time_point<std::chrono::high_resolution_clock> m_start; std::chrono::time_point<std::chrono::high_resolution_clock> m_start;
...@@ -108,9 +126,11 @@ namespace ngraph ...@@ -108,9 +126,11 @@ namespace ngraph
std::string m_category; std::string m_category;
std::string m_args; std::string m_args;
static std::mutex s_file_mutex; NGRAPH_API static std::mutex s_file_mutex;
static std::ofstream s_event_log; NGRAPH_API static std::ofstream s_event_log;
static bool s_tracing_enabled; NGRAPH_API static bool s_tracing_enabled;
NGRAPH_API static std::function<void(const Event& event)> s_event_writer;
NGRAPH_API static bool s_event_writer_registered;
}; };
} // namespace ngraph } // namespace ngraph
...@@ -29,7 +29,6 @@ using namespace std; ...@@ -29,7 +29,6 @@ using namespace std;
TEST(event_tracing, event_file) TEST(event_tracing, event_file)
{ {
// Set the environment variable to ensure logging
ngraph::Event::enable_event_tracing(); ngraph::Event::enable_event_tracing();
std::vector<std::thread> threads; std::vector<std::thread> threads;
for (auto i = 0; i < 10; i++) for (auto i = 0; i < 10; i++)
...@@ -39,16 +38,14 @@ TEST(event_tracing, event_file) ...@@ -39,16 +38,14 @@ TEST(event_tracing, event_file)
std::ostringstream oss; std::ostringstream oss;
oss << "Event: " << id; oss << "Event: " << id;
ngraph::Event event(oss.str(), "Dummy", "none"); ngraph::Event event(oss.str(), "Dummy", "none");
std::this_thread::sleep_for(std::chrono::milliseconds(200)); std::this_thread::sleep_for(std::chrono::milliseconds(20));
event.Stop(); event.Stop();
ngraph::Event::write_trace(event); ngraph::Event::write_trace(event);
}); });
std::this_thread::sleep_for(std::chrono::milliseconds(200)); std::this_thread::sleep_for(std::chrono::milliseconds(20));
threads.push_back(std::move(next_thread)); threads.push_back(std::move(next_thread));
} }
std::this_thread::sleep_for(std::chrono::milliseconds(200));
for (auto& next : threads) for (auto& next : threads)
{ {
next.join(); next.join();
...@@ -62,3 +59,55 @@ TEST(event_tracing, event_file) ...@@ -62,3 +59,55 @@ TEST(event_tracing, event_file)
// TODO // TODO
ngraph::Event::disable_event_tracing(); ngraph::Event::disable_event_tracing();
} }
TEST(event_tracing, event_writer_callback)
{
// Create the event writer
vector<ngraph::Event> event_list;
auto event_writer = [&](const ngraph::Event& event) { event_list.push_back(event); };
map<string, unique_ptr<ngraph::Event>> expected_event_table;
mutex expected_event_table_mtx;
ngraph::Event::enable_event_tracing();
ngraph::Event::register_event_writer(event_writer);
auto worker = [&](int worker_id) {
std::ostringstream oss;
oss << "Event: " << worker_id;
unique_ptr<ngraph::Event> event(new ngraph::Event(oss.str(), "Dummy", "none"));
std::this_thread::sleep_for(std::chrono::milliseconds(20));
event->Stop();
ngraph::Event::write_trace(*event);
lock_guard<mutex> lock(expected_event_table_mtx);
expected_event_table[event->get_name()] = move(event);
};
std::vector<std::thread> threads;
for (int i = 0; i < 10; i++)
{
std::thread thread_next(worker, i);
threads.push_back(move(thread_next));
}
for (auto& next : threads)
{
next.join();
}
ngraph::Event::disable_event_tracing();
// Now validate the events
ASSERT_EQ(10, event_list.size());
ASSERT_EQ(10, expected_event_table.size());
for (const auto& next_event : event_list)
{
const auto& expected_event_key = expected_event_table.find(next_event.get_name());
EXPECT_TRUE(expected_event_key != expected_event_table.end());
EXPECT_EQ(expected_event_key->second->get_name(), next_event.get_name());
EXPECT_EQ(expected_event_key->second->get_start(), next_event.get_start());
EXPECT_EQ(expected_event_key->second->get_stop(), next_event.get_stop());
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment