Unverified Commit 3d28d06a authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Let dist interface control logging output (#2930)

parent d19ae275
......@@ -31,6 +31,7 @@ namespace ngraph
virtual const std::string& get_name() const = 0;
virtual int get_size() = 0;
virtual int get_rank() = 0;
virtual void log_print(const std::string& timestamp, const std::vector<char>& buf) = 0;
virtual void
all_reduce(void* in, void* out, element::Type_t element_type, size_t count) = 0;
......
......@@ -60,6 +60,11 @@ namespace ngraph
return static_cast<int>(MLSL::Environment::GetEnv().GetProcessIdx());
}
void log_print(const std::string& timestamp, const std::vector<char>& buf) override
{
std::printf("%s [MLSL RANK: %d]: %s\n", timestamp.c_str(), get_rank(), buf.data());
}
void
all_reduce(void* in, void* out, element::Type_t element_type, size_t count) override
{
......
......@@ -16,6 +16,7 @@
#pragma once
#include <cstdio>
#include <string>
#include "ngraph/distributed.hpp"
......@@ -30,6 +31,10 @@ namespace ngraph
const std::string& get_name() const override { return m_name; }
int get_size() override { return 0; }
int get_rank() override { return 0; }
void log_print(const std::string& timestamp, const std::vector<char>& buf) override
{
std::printf("%s: %s\n", timestamp.c_str(), buf.data());
}
void
all_reduce(void* in, void* out, element::Type_t element_type, size_t count) override
{
......
......@@ -16,6 +16,7 @@
#pragma once
#include <cstdio>
#include <iostream>
#include "ngraph/distributed.hpp"
......@@ -70,6 +71,12 @@ namespace ngraph
return rank;
}
void log_print(const std::string& timestamp, const std::vector<char>& buf) override
{
std::printf(
"%s [OpenMPI RANK: %d]: %s\n", timestamp.c_str(), get_rank(), buf.data());
}
void
all_reduce(void* in, void* out, element::Type_t element_type, size_t count) override
{
......
......@@ -112,11 +112,7 @@ void ngraph::LogPrintf(const char* fmt, ...)
std::vsnprintf(buf.data(), buf.size(), fmt, args2);
#pragma GCC diagnostic pop
va_end(args2);
std::printf("%s [RANK: %d]: %s\n",
get_timestamp().c_str(),
get_distributed_interface()->get_rank(),
buf.data());
get_distributed_interface()->log_print(get_timestamp(), buf);
}
// This function will be executed only once during startup (loading of the DSO)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment