Unverified Commit c2e46149 authored by Scott Cyphers, committed by GitHub

Refactor distributed to be isolated to a few files. (#2828)

* Refactor distributed to be isolated to a few files.

* Fix type-o

* style

* review comments

* type-o

* typo

* Return name
parent a33ad828
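For orientation, here is a minimal sketch (not part of the diff) of what caller code looks like after this refactor: sites that previously carried their own MLSL/MPI `#ifdef` blocks now go through the single process-wide interface declared in `ngraph/distributed.hpp`. The function and buffer names below are illustrative only.

```cpp
#include <vector>

#include "ngraph/distributed.hpp"

// Hypothetical caller; the real call sites are the CPU builders and the
// interpreter shown further down in this diff.
void allreduce_example()
{
    std::vector<float> in(16, 1.0f);
    std::vector<float> out(16, 0.0f);

    ngraph::DistributedInterface* dist = ngraph::get_distributed_interface();

    // Rank/size queries no longer need NGRAPH_DISTRIBUTED_* macros at the call site.
    int rank = dist->get_rank();
    int size = dist->get_size();
    (void)rank;
    (void)size;

    // Sum-reduce across processes; the concrete backend (OpenMPI, MLSL, or the
    // null implementation) is selected once inside distributed.cpp.
    dist->all_reduce(in.data(), out.data(), ngraph::element::Type_t::f32, in.size());
}
```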
......@@ -59,6 +59,8 @@ set (SRC
descriptor/tensor.hpp
dimension.cpp
dimension.hpp
distributed.cpp
distributed.hpp
except.hpp
file_util.cpp
file_util.hpp
......@@ -427,16 +429,11 @@ if(NGRAPH_JSON_ENABLE)
list(APPEND SRC serializer.cpp serializer.hpp event_tracing.cpp event_tracing.hpp)
endif()
if(NGRAPH_DISTRIBUTED_ENABLE)
list(APPEND SRC distributed.cpp distributed.hpp)
endif()
configure_file(version.in.hpp version.hpp)
add_library(ngraph SHARED ${SRC})
if(NGRAPH_DISTRIBUTED_ENABLE)
target_sources(ngraph PRIVATE distributed.cpp)
if(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
target_include_directories(ngraph SYSTEM PRIVATE libmlsl)
target_link_libraries(ngraph PRIVATE libmlsl)
......
......@@ -50,21 +50,6 @@ endif()
list(APPEND HEADER_SEARCH_DEFINES CLANG_BUILTIN_HEADERS_PATH="${CLANG_INCLUDE_DIR}")
list(APPEND HEADER_SEARCH_DEFINES NGRAPH_HEADERS_PATH="${NGRAPH_INCLUDE_PATH}")
if(NGRAPH_DISTRIBUTED_ENABLE)
if (NGRAPH_DISTRIBUTED_MLSL_ENABLE)
get_target_property(MLSL_INCLUDE_DIR libmlsl INTERFACE_INCLUDE_DIRECTORIES)
list(APPEND HEADER_SEARCH_DEFINES MLSL_HEADER_PATH="${MLSL_INCLUDE_DIR}")
elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
find_package(MPI REQUIRED)
# MPI_C_INCLUDE_PATH might have a list of directories separated by a semicolon
# Escape the semicolon to prevent cmake from interpreting the string as a list
string(REPLACE ";" "\\\;" MPI_C_INCLUDE_PATH_ESCAPED "${MPI_C_INCLUDE_PATH}")
list(APPEND HEADER_SEARCH_DEFINES MPI_HEADER_PATH="${MPI_C_INCLUDE_PATH_ESCAPED}")
else()
message(FATAL_ERROR "Distributed Library not supported/mentioned")
endif()
endif()
if(NGRAPH_GPU_ENABLE)
find_package(CUDA 8 REQUIRED)
find_package(CUDNN 7 REQUIRED)
......
......@@ -558,16 +558,6 @@ void codegen::CompilerCore::configure_search_path()
// Only needed for GPU backend
add_header_search_path(CUDNN_HEADER_PATHS);
#endif
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
add_header_search_path(MLSL_HEADER_PATH);
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
add_header_search_path(MPI_HEADER_PATH);
#else
throw ngraph_error("Distributed Library not supported/mentioned");
#endif
#endif
}
void codegen::CompilerCore::load_headers_from_resource()
......
......@@ -14,90 +14,36 @@
// limitations under the License.
//*****************************************************************************
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
#include <mlsl.hpp>
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
#include <mpi.h>
#endif
#include "ngraph/distributed.hpp"
#include "ngraph/distributed/mlsl.hpp"
#include "ngraph/distributed/null.hpp"
#include "ngraph/distributed/open_mpi.hpp"
#include "ngraph/log.hpp"
using namespace ngraph;
ngraph::Distributed::Distributed()
{
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
if (!MLSL::Environment::GetEnv().IsInitialized())
{
MLSL::Environment::GetEnv().Init(nullptr, nullptr);
m_init_comm = true;
}
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
int flag = 0;
MPI_Initialized(&flag);
if (!flag)
{
MPI_Init(NULL, NULL);
m_init_comm = true;
}
#else
throw ngraph_error("Distributed Library not supported/mentioned");
#endif
}
static std::unique_ptr<DistributedInterface> s_distributed_interface;
ngraph::Distributed::~Distributed()
void ngraph::set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface)
{
if (m_init_comm == true)
{
finalize();
}
NGRAPH_DEBUG << "Setting distributed interface to: " << distributed_interface->get_name();
s_distributed_interface = std::move(distributed_interface);
}
void ngraph::Distributed::finalize()
DistributedInterface* ngraph::get_distributed_interface()
{
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
if (MLSL::Environment::GetEnv().IsInitialized())
{
MLSL::Environment::GetEnv().Finalize();
}
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
int flag = 0;
MPI_Initialized(&flag);
if (flag)
if (0 == s_distributed_interface)
{
MPI_Finalize();
}
#else
throw ngraph_error("Distributed Library not supported/mentioned");
#endif
}
int ngraph::Distributed::get_size() const
{
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
return static_cast<int>(MLSL::Environment::GetEnv().GetProcessCount());
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
int size;
MPI_Comm_size(MPI_COMM_WORLD, &size);
return size;
#ifdef NGRAPH_DISTRIBUTED_OMPI_ENABLE
set_distributed_interface(std::unique_ptr<DistributedInterface>(
new ngraph::distributed::OpenMPIDistributedInterface()));
#elif defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
set_distributed_interface(std::unique_ptr<DistributedInterface>(
new ngraph::distributed::MLSLDistributedInterface()));
#else
throw ngraph_error("Distributed Library not supported/mentioned");
#endif
}
int ngraph::Distributed::get_rank() const
{
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
return static_cast<int>(MLSL::Environment::GetEnv().GetProcessIdx());
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return rank;
#else
throw ngraph_error("Distributed Library not supported/mentioned");
set_distributed_interface(std::unique_ptr<DistributedInterface>(
new ngraph::distributed::NullDistributedInterface()));
#endif
}
return s_distributed_interface.get();
}
#endif
......@@ -17,19 +17,26 @@
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
class Distributed
class DistributedInterface
{
public:
Distributed();
~Distributed();
int get_size() const;
int get_rank() const;
virtual ~DistributedInterface() {}
virtual const std::string& get_name() const = 0;
virtual int get_size() = 0;
virtual int get_rank() = 0;
private:
bool m_init_comm = false;
void finalize();
virtual void
all_reduce(void* in, void* out, element::Type_t element_type, size_t count) = 0;
virtual void broadcast(void* in, element::Type_t element_type, size_t count) = 0;
};
void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface);
DistributedInterface* get_distributed_interface();
}
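As a rough illustration of the extension point (this class is not in the commit; it mirrors the NullDistributedInterface added below), a backend only has to override the five virtuals and hand an instance to `set_distributed_interface`:

```cpp
#include <memory>
#include <string>

#include "ngraph/distributed.hpp"
#include "ngraph/except.hpp"

// Hypothetical single-process stand-in, modelled on the null implementation.
class SingleProcessInterface : public ngraph::DistributedInterface
{
public:
    const std::string& get_name() const override { return m_name; }
    int get_size() override { return 1; }
    int get_rank() override { return 0; }
    void all_reduce(void* in, void* out, ngraph::element::Type_t, size_t) override
    {
        throw ngraph::ngraph_error("all_reduce requires a distributed build");
    }
    void broadcast(void* in, ngraph::element::Type_t, size_t) override
    {
        throw ngraph::ngraph_error("broadcast requires a distributed build");
    }

private:
    std::string m_name{"single-process"};
};

void install_single_process_interface()
{
    // Same registration style as distributed.cpp above.
    ngraph::set_distributed_interface(
        std::unique_ptr<ngraph::DistributedInterface>(new SingleProcessInterface()));
}
```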
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
#include <string>
#include <mlsl.hpp>
#include "ngraph/distributed.hpp"
namespace ngraph
{
namespace distributed
{
class MLSLDistributedInterface : public DistributedInterface
{
public:
MLSLDistributedInterface(const std::string& name = "MLSL")
: m_name(name)
{
if (!MLSL::Environment::GetEnv().IsInitialized())
{
MLSL::Environment::GetEnv().Init(nullptr, nullptr);
}
}
~MLSLDistributedInterface() override
{
if (MLSL::Environment::GetEnv().IsInitialized())
{
MLSL::Environment::GetEnv().Finalize();
}
}
const std::string& get_name() const override { return m_name; }
int get_size() override
{
return static_cast<int>(MLSL::Environment::GetEnv().GetProcessCount());
}
int get_rank() override
{
return static_cast<int>(MLSL::Environment::GetEnv().GetProcessIdx());
}
void
all_reduce(void* in, void* out, element::Type_t element_type, size_t count) override
{
auto data_type = MLSL::DT_FLOAT;
if (element_type == element::Type_t::f32)
{
data_type = MLSL::DT_FLOAT;
}
else if (element_type == element::Type_t::f64)
{
data_type = MLSL::DT_DOUBLE;
}
else
{
throw std::runtime_error("AllReduce op supports only f32 and f64 types");
}
MLSL::Environment& env = MLSL::Environment::GetEnv();
MLSL::Distribution* distribution = env.CreateDistribution(env.GetProcessCount(), 1);
MLSL::CommReq* req =
distribution->AllReduce(in, out, count, data_type, MLSL::RT_SUM, MLSL::GT_DATA);
env.Wait(req);
env.DeleteDistribution(distribution);
}
void broadcast(void* in, element::Type_t element_type, size_t count) override
{
auto data_type = MLSL::DT_FLOAT;
if (element_type == element::Type_t::f64)
{
data_type = MLSL::DT_DOUBLE;
}
else if (element_type != element::Type_t::f32)
{
throw std::runtime_error(
"BroadcastDistributed op supports only f32 and f64 types");
}
MLSL::Environment& env = MLSL::Environment::GetEnv();
MLSL::Distribution* distribution = env.CreateDistribution(env.GetProcessCount(), 1);
MLSL::CommReq* req = distribution->Bcast(in, count, data_type, 0, MLSL::GT_DATA);
env.Wait(req);
env.DeleteDistribution(distribution);
}
protected:
std::string m_name{"MLSL"};
};
}
}
#endif
......@@ -14,19 +14,35 @@
// limitations under the License.
//*****************************************************************************
#include <iostream>
#pragma once
namespace ngraph_dist_setup
{
static int distributed_comm_size;
static int distributed_comm_rank;
}
#include <string>
#include "ngraph/distributed.hpp"
#include "ngraph/except.hpp"
class DistributedSetup
namespace ngraph
{
public:
int get_comm_size();
int get_comm_rank();
void set_comm_size(int comm_size);
void set_comm_rank(int comm_rank);
};
namespace distributed
{
class NullDistributedInterface : public DistributedInterface
{
const std::string& get_name() const override { return m_name; }
int get_size() override { return 0; }
int get_rank() override { return 0; }
void
all_reduce(void* in, void* out, element::Type_t element_type, size_t count) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
void broadcast(void* in, element::Type_t element_type, size_t count) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
protected:
std::string m_name{"NULL"};
};
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <iostream>
#include "ngraph/distributed.hpp"
#ifdef NGRAPH_DISTRIBUTED_OMPI_ENABLE
#include <string>
#include <mpi.h>
namespace ngraph
{
namespace distributed
{
class OpenMPIDistributedInterface : public DistributedInterface
{
public:
OpenMPIDistributedInterface(const std::string& name = "OpenMPI")
: m_name(name)
{
int flag = 0;
MPI_Initialized(&flag);
if (!flag)
{
MPI_Init(NULL, NULL);
}
}
~OpenMPIDistributedInterface() override
{
int flag = 0;
MPI_Initialized(&flag);
if (flag)
{
MPI_Finalize();
}
}
const std::string& get_name() const override { return m_name; }
int get_size() override
{
int size;
MPI_Comm_size(MPI_COMM_WORLD, &size);
return size;
}
int get_rank() override
{
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return rank;
}
void
all_reduce(void* in, void* out, element::Type_t element_type, size_t count) override
{
auto data_type = MPI_FLOAT;
if (element_type == element::Type_t::f32)
{
data_type = MPI_FLOAT;
}
else if (element_type == element::Type_t::f64)
{
data_type = MPI_DOUBLE;
}
else
{
throw std::runtime_error("AllReduce op supports only f32 and f64 types");
}
MPI_Allreduce(in, out, count, data_type, MPI_SUM, MPI_COMM_WORLD);
}
void broadcast(void* in, element::Type_t element_type, size_t count) override
{
auto data_type = MPI_FLOAT;
if (element_type == element::Type_t::f64)
{
data_type = MPI_DOUBLE;
}
else if (element_type != element::Type_t::f32)
{
throw std::runtime_error(
"BroadcastDistributed op supports only f32 and f64 types");
}
MPI_Bcast(in, count, data_type, 0, MPI_COMM_WORLD);
}
protected:
std::string m_name;
};
}
}
#endif
......@@ -113,12 +113,10 @@ void ngraph::LogPrintf(const char* fmt, ...)
#pragma GCC diagnostic pop
va_end(args2);
#ifdef NGRAPH_DISTRIBUTED_ENABLE
ngraph::Distributed dist;
std::printf("%s [RANK: %d]: %s\n", get_timestamp().c_str(), dist.get_rank(), buf.data());
#else
std::printf("%s %s\n", get_timestamp().c_str(), buf.data());
#endif
std::printf("%s [RANK: %d]: %s\n",
get_timestamp().c_str(),
get_distributed_interface()->get_rank(),
buf.data());
}
// This function will be executed only once during startup (loading of the DSO)
......
......@@ -30,9 +30,7 @@
#endif
#include <vector>
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#include "ngraph/distributed.hpp"
#endif
namespace ngraph
{
......
......@@ -180,21 +180,6 @@ if (NGRAPH_CPU_ENABLE)
endif()
target_compile_definitions(cpu_backend PRIVATE CPU_BACKEND_DLL_EXPORTS)
if(NGRAPH_DISTRIBUTED_ENABLE)
if(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
target_include_directories(cpu_backend SYSTEM PRIVATE libmlsl)
target_link_libraries(cpu_backend PRIVATE libmlsl)
elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
find_package(MPI REQUIRED)
target_include_directories(cpu_backend
SYSTEM PRIVATE ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
target_link_libraries(cpu_backend
PRIVATE ${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
else()
message(FATAL_ERROR "Distributed Library not supported/mentioned")
endif()
endif()
add_dependencies(cpu_backend libmkldnn ext_eigen)
target_link_libraries(cpu_backend PUBLIC ngraph libmkldnn libmkl libeigen libtbb)
if (NGRAPH_JSON_ENABLE)
......
......@@ -13,16 +13,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
#include <mlsl.hpp>
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
#include <mpi.h>
#endif
#include "ngraph/log.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/log.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
using namespace std;
......@@ -43,6 +36,7 @@ namespace ngraph
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto count = static_cast<int>(out[0].get_size());
auto data_type = args[0].get_element_type().get_type_enum();
auto external_function_name = external_function->get_function_name();
NGRAPH_DEBUG_PRINT(
......@@ -54,56 +48,11 @@ namespace ngraph
node->get_friendly_name().c_str(),
count);
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
auto data_type = MLSL::DT_FLOAT;
if (args[0].get_element_type() == element::f32)
{
data_type = MLSL::DT_FLOAT;
}
else if (args[0].get_element_type() == element::f64)
{
data_type = MLSL::DT_DOUBLE;
}
auto functor = [&, count, data_type](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
MLSL::CommReq* req = ctx->mlsl_dist->AllReduce(
arg_tensor, out_tensor, count, data_type, MLSL::RT_SUM, MLSL::GT_DATA);
ctx->mlsl_env->Wait(req);
};
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
auto data_type = MPI_FLOAT;
if (args[0].get_element_type() == element::f32)
{
data_type = MPI_FLOAT;
}
else if (args[0].get_element_type() == element::f64)
{
data_type = MPI_DOUBLE;
}
auto node_friendly_name = node->get_friendly_name();
auto node_name = node->get_name();
auto func_name = external_function->get_function_name();
int id = call_seq;
call_seq++;
auto functor = [&, id, count, data_type, func_name, node_friendly_name, node_name](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
NGRAPH_DEBUG_PRINT("AllReduce Execute[%d]: Function: %s Node: %s %s Size: %d",
id,
func_name.c_str(),
node_name.c_str(),
node_friendly_name.c_str(),
count);
MPI_Allreduce(
arg_tensor, out_tensor, count, data_type, MPI_SUM, MPI_COMM_WORLD);
get_distributed_interface()->all_reduce(
arg_tensor, out_tensor, data_type, count);
};
#else
throw ngraph_error("Distributed Library not supported/mentioned");
#endif
functors.emplace_back(functor);
}
......@@ -111,4 +60,3 @@ namespace ngraph
}
}
}
#endif
......@@ -13,13 +13,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
#include <mlsl.hpp>
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
#include <mpi.h>
#endif
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
......@@ -40,49 +33,14 @@ namespace ngraph
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto count = static_cast<int>(args[0].get_size());
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
auto data_type = MLSL::DT_FLOAT;
if (args[0].get_element_type() == element::f32)
{
data_type = MLSL::DT_FLOAT;
}
else if (args[0].get_element_type() == element::f64)
{
data_type = MLSL::DT_DOUBLE;
}
auto functor = [&, count, data_type](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
MLSL::CommReq* req =
ctx->mlsl_dist->Bcast(arg_tensor, count, data_type, 0, MLSL::GT_DATA);
ctx->mlsl_env->Wait(req);
};
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
auto data_type = MPI_FLOAT;
if (args[0].get_element_type() == element::f32)
{
data_type = MPI_FLOAT;
}
else if (args[0].get_element_type() == element::f64)
{
data_type = MPI_DOUBLE;
}
auto data_type = args[0].get_element_type().get_type_enum();
auto functor = [&, count, data_type](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
MPI_Bcast(arg_tensor, count, data_type, 0, MPI_COMM_WORLD);
get_distributed_interface()->broadcast(arg_tensor, data_type, count);
};
#else
throw ngraph_error("Distributed Library not supported/mentioned");
#endif
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(BroadcastDistributed);
}
}
}
#endif
......@@ -27,6 +27,7 @@
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
......@@ -107,11 +108,6 @@
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
#ifdef NGRAPH_DISTRIBUTED_OMPI_ENABLE
#include <mpi.h>
#include "ngraph/op/allreduce.hpp"
#endif
using namespace std;
using namespace ngraph;
......
......@@ -22,10 +22,6 @@
#include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
#include "ngraph/runtime/cpu/cpu_tracing.hpp"
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
#include <mlsl.hpp>
#endif
using namespace std;
using namespace ngraph;
......@@ -160,14 +156,6 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
const auto parallelism = envParallelism == nullptr ? 1 : std::atoi(envParallelism);
ctx->c = new tbb::global_control(tbb::global_control::max_allowed_parallelism, parallelism);
}
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
if (MLSL::Environment::GetEnv().IsInitialized())
{
ctx->mlsl_env = &MLSL::Environment::GetEnv();
ctx->mlsl_dist = ctx->mlsl_env->CreateDistribution(ctx->mlsl_env->GetProcessCount(), 1);
}
#endif
}
void runtime::cpu::CPU_CallFrame::cleanup_runtime_context()
......@@ -197,12 +185,5 @@ void runtime::cpu::CPU_CallFrame::cleanup_runtime_context()
}
delete ctx->c;
}
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
if (MLSL::Environment::GetEnv().IsInitialized() && ctx->mlsl_dist != nullptr)
{
ctx->mlsl_env->DeleteDistribution(ctx->mlsl_dist);
}
#endif
delete ctx;
}
......@@ -132,15 +132,6 @@
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
#include <mlsl.hpp>
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
#include <mpi.h>
#endif
#include "ngraph/op/allreduce.hpp"
#endif
using namespace std;
using namespace ngraph;
......@@ -196,94 +187,23 @@ namespace ngraph
writer.block_end();
}
#ifdef NGRAPH_DISTRIBUTED_ENABLE
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::AllReduce)
{
const element::Type& element_type = args[0].get_element_type();
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
auto data_type = "MLSL::DT_FLOAT";
if (element_type == element::f32)
{
data_type = "MLSL::DT_FLOAT";
}
else if (element_type == element::f64)
{
data_type = "MLSL::DT_DOUBLE";
}
writer.block_begin();
writer << "MLSL::CommReq* req = ctx->mlsl_dist->AllReduce(" << args[0].get_name()
<< ", " << out[0].get_name() << ", " << out[0].get_size() << ", "
<< data_type << ", MLSL::RT_SUM, MLSL::GT_DATA);\n";
writer << "ctx->mlsl_env->Wait(req);\n";
writer.block_end();
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
auto data_type = "MPI_FLOAT";
if (element_type == element::f32)
{
data_type = "MPI_FLOAT";
}
else if (element_type == element::f64)
{
data_type = "MPI_DOUBLE";
}
writer.block_begin();
writer << "MPI_Allreduce(" << args[0].get_name() << ", " << out[0].get_name()
<< ", " << out[0].get_size() << ", " << data_type
<< ", MPI_SUM, MPI_COMM_WORLD);\n";
writer.block_end();
#else
throw ngraph_error("Distributed Library not supported/mentioned");
#endif
writer << "ngraph::get_distributed_interface()->all_reduce(" << args[0].get_name()
<< ", " << out[0].get_name() << ", "
<< "ngraph::element::Type_t::" << args[0].get_element_type().get_type_name()
<< ", " << out[0].get_size() << ");\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BroadcastDistributed)
{
const element::Type& element_type = args[0].get_element_type();
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
auto data_type = "MLSL::DT_FLOAT";
if (element_type == element::f32)
{
data_type = "MLSL::DT_FLOAT";
}
else if (element_type == element::f64)
{
data_type = "MLSL::DT_DOUBLE";
}
writer.block_begin();
writer << "MLSL::CommReq* req = ctx->mlsl_dist->Bcast(" << args[0].get_name()
<< ", " << args[0].get_size() << ", " << data_type
<< ", 0, MLSL::GT_DATA);\n";
writer << "ctx->mlsl_env->Wait(req);\n";
writer.block_end();
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
auto data_type = "MPI_FLOAT";
if (element_type == element::f32)
{
data_type = "MPI_FLOAT";
}
else if (element_type == element::f64)
{
data_type = "MPI_DOUBLE";
}
writer.block_begin();
writer << "MPI_Bcast(" << args[0].get_name() << ", " << args[0].get_size() << ", "
<< data_type << ", 0, MPI_COMM_WORLD);\n";
writer.block_end();
#else
throw ngraph_error("Distributed Library not supported/mentioned");
#endif
writer << "ngraph::get_distributed_interface()->broadcast(" << args[0].get_name()
<< ", "
<< "ngraph::element::Type_t::" << args[0].get_element_type().get_type_name()
<< ", " << args[0].get_size() << ");\n;";
}
#endif
static void emitCblasSgemmBatch(CodeWriter& writer,
const Shape& shape_a,
......
......@@ -53,6 +53,7 @@
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
......@@ -186,11 +187,6 @@
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#endif
using namespace std;
using namespace ngraph;
......@@ -297,11 +293,9 @@ static StaticInitializers s_static_initializers(s_output_dir);
static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::CPU_Emitter::emit<op::Add>},
#ifdef NGRAPH_DISTRIBUTED_ENABLE
{TI(ngraph::op::AllReduce), &runtime::cpu::CPU_Emitter::emit<op::AllReduce>},
{TI(ngraph::op::BroadcastDistributed),
&runtime::cpu::CPU_Emitter::emit<op::BroadcastDistributed>},
#endif
{TI(ngraph::op::MatmulBias), &runtime::cpu::CPU_Emitter::emit<op::MatmulBias>},
{TI(ngraph::op::Dot), &runtime::cpu::CPU_Emitter::emit<op::Dot>},
{TI(ngraph::op::Multiply), &runtime::cpu::CPU_Emitter::emit<op::Multiply>},
......@@ -513,21 +507,10 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_
}
writer << "#include <tbb/flow_graph.h>";
}
#ifdef NGRAPH_DISTRIBUTED_ENABLE
writer << "#define NGRAPH_DISTRIBUTED_ENABLE\n";
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
writer << "#include <mlsl.hpp>\n";
writer << "#define NGRAPH_DISTRIBUTED_MLSL_ENABLE\n";
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
writer << "#include <mpi.h>\n";
writer << "#define NGRAPH_DISTRIBUTED_OMPI_ENABLE\n";
#endif
#endif
writer +=
R"(
#include <cmath>
#include "ngraph/distributed.hpp"
#include "ngraph/except.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/cpu/cpu_eigen_utils.hpp"
......
......@@ -26,10 +26,6 @@
#include <tbb/global_control.h>
#include <tbb/task_scheduler_init.h>
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
#include <mlsl.hpp>
#endif
namespace mkldnn
{
class primitive;
......@@ -69,10 +65,6 @@ namespace ngraph
State* const* states;
std::set<size_t> breakpoints;
size_t pc;
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
MLSL::Environment* mlsl_env;
MLSL::Distribution* mlsl_dist;
#endif
};
}
......
......@@ -22,11 +22,8 @@
#include <string>
#include <vector>
#include "ngraph/runtime/tensor.hpp"
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#include "ngraph/runtime/reference/allreduce.hpp"
#endif
#include "ngraph/runtime/tensor.hpp"
namespace ngraph
{
......
......@@ -69,6 +69,7 @@
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/all.hpp"
#include "ngraph/runtime/reference/allreduce.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/any.hpp"
#include "ngraph/runtime/reference/argmax.hpp"
......@@ -77,6 +78,7 @@
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/broadcast_distributed.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/constant.hpp"
......@@ -136,11 +138,6 @@
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/state/rng_state.hpp"
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#include "ngraph/runtime/reference/allreduce.hpp"
#include "ngraph/runtime/reference/broadcast_distributed.hpp"
#endif
namespace ngraph
{
namespace runtime
......@@ -234,13 +231,12 @@ private:
all->get_reduction_axes());
break;
}
case OP_TYPEID::AllReduce: {
#ifdef NGRAPH_DISTRIBUTED_ENABLE
case OP_TYPEID::AllReduce:
{
reference::allreduce<T>(static_cast<T*>(const_cast<void*>(args[0])),
static_cast<T*>(out[0]),
node.get_input_element_type(0),
static_cast<int>(shape_size(node.get_input_shape(0))));
#endif
break;
}
case OP_TYPEID::And:
......@@ -441,12 +437,10 @@ private:
broadcast_axes);
break;
}
case OP_TYPEID::BroadcastDistributed: {
#ifdef NGRAPH_DISTRIBUTED_ENABLE
Distributed dist;
int Rank_ID;
Rank_ID = dist.get_rank();
if (Rank_ID == 0)
case OP_TYPEID::BroadcastDistributed:
{
int rank_ID = get_distributed_interface()->get_rank();
if (rank_ID == 0)
{
reference::broadcastdistributed<T>(
static_cast<T*>(args[0]),
......@@ -464,8 +458,6 @@ private:
static_cast<int>(shape_size(node.get_input_shape(0))));
}
break;
#endif
break;
}
case OP_TYPEID::BroadcastLike: break;
case OP_TYPEID::Ceiling:
......
......@@ -23,21 +23,6 @@ if (NGRAPH_INTERPRETER_ENABLE)
endif()
target_link_libraries(interpreter_backend PUBLIC ngraph)
if(NGRAPH_DISTRIBUTED_ENABLE)
if(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
target_include_directories(interpreter_backend SYSTEM PRIVATE libmlsl)
target_link_libraries(interpreter_backend PRIVATE libmlsl)
elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
find_package(MPI REQUIRED)
target_include_directories(interpreter_backend
SYSTEM PRIVATE ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
target_link_libraries(interpreter_backend
PRIVATE ${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
else()
message(FATAL_ERROR "Distributed Library not supported/mentioned")
endif()
endif()
install(TARGETS interpreter_backend
LIBRARY DESTINATION "${NGRAPH_INSTALL_LIB}"
ARCHIVE DESTINATION "${NGRAPH_INSTALL_LIB}"
......
......@@ -69,6 +69,7 @@
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/all.hpp"
#include "ngraph/runtime/reference/allreduce.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/any.hpp"
#include "ngraph/runtime/reference/argmax.hpp"
......@@ -79,6 +80,7 @@
#include "ngraph/runtime/reference/batch_mat_mul.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/broadcast_distributed.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/constant.hpp"
......@@ -142,11 +144,6 @@
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/state/rng_state.hpp"
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#include "ngraph/runtime/reference/allreduce.hpp"
#include "ngraph/runtime/reference/broadcast_distributed.hpp"
#endif
namespace ngraph
{
namespace runtime
......@@ -239,13 +236,12 @@ private:
all->get_reduction_axes());
break;
}
case OP_TYPEID::AllReduce: {
#ifdef NGRAPH_DISTRIBUTED_ENABLE
case OP_TYPEID::AllReduce:
{
reference::allreduce<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_element_type(0),
node.get_input_element_type(0).get_type_enum(),
static_cast<int>(shape_size(node.get_input_shape(0))));
#endif
break;
}
case OP_TYPEID::And:
......@@ -456,31 +452,27 @@ private:
broadcast_axes);
break;
}
case OP_TYPEID::BroadcastDistributed: {
#ifdef NGRAPH_DISTRIBUTED_ENABLE
Distributed dist;
int Rank_ID;
Rank_ID = dist.get_rank();
if (Rank_ID == 0)
case OP_TYPEID::BroadcastDistributed:
{
int rank_ID;
rank_ID = get_distributed_interface()->get_rank();
if (rank_ID == 0)
{
reference::broadcastdistributed<T>(
args[0]->get_data_ptr<T>(),
node.get_input_element_type(0),
node.get_input_element_type(0).get_type_enum(),
static_cast<int>(shape_size(node.get_input_shape(0))));
auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) *
sizeof(node.get_input_element_type(0));
auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) * sizeof(T);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
}
else
{
reference::broadcastdistributed<T>(
out[0]->get_data_ptr<T>(),
node.get_input_element_type(0),
node.get_input_element_type(0).get_type_enum(),
static_cast<int>(shape_size(node.get_input_shape(0))));
}
break;
#endif
break;
}
case OP_TYPEID::BroadcastLike: break;
case OP_TYPEID::Ceiling:
......
......@@ -16,13 +16,7 @@
#pragma once
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
#include <mlsl.hpp>
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
#include <mpi.h>
#endif
#include "ngraph/type/element_type.hpp"
#include "ngraph/distributed.hpp"
namespace ngraph
{
......@@ -31,53 +25,10 @@ namespace ngraph
namespace reference
{
template <typename T>
void allreduce(T* arg, T* out, const element::Type element_type, int count)
void allreduce(T* arg, T* out, const element::Type_t element_type, int count)
{
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
auto data_type = MLSL::DT_FLOAT;
if (element_type == element::f32)
{
data_type = MLSL::DT_FLOAT;
}
else if (element_type == element::f64)
{
data_type = MLSL::DT_DOUBLE;
}
else
{
throw std::runtime_error("AllReduce op supports only f32 and f64 types");
}
MLSL::Environment& env = MLSL::Environment::GetEnv();
MLSL::Distribution* distribution = env.CreateDistribution(env.GetProcessCount(), 1);
MLSL::CommReq* req = distribution->AllReduce(
arg, out, count, data_type, MLSL::RT_SUM, MLSL::GT_DATA);
env.Wait(req);
env.DeleteDistribution(distribution);
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
auto data_type = MPI_FLOAT;
if (element_type == element::f32)
{
data_type = MPI_FLOAT;
}
else if (element_type == element::f64)
{
data_type = MPI_DOUBLE;
}
else
{
throw std::runtime_error("AllReduce op supports only f32 and f64 types");
}
MPI_Allreduce(arg, out, count, data_type, MPI_SUM, MPI_COMM_WORLD);
#else
throw ngraph_error("Distributed Library not supported/mentioned");
#endif
get_distributed_interface()->all_reduce(arg, out, element_type, count);
}
}
}
}
#endif
......@@ -16,13 +16,9 @@
#pragma once
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
#include <mlsl.hpp>
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
#include <mpi.h>
#endif
#include "ngraph/type/element_type.hpp"
#pragma once
#include "ngraph/distributed.hpp"
namespace ngraph
{
......@@ -31,45 +27,10 @@ namespace ngraph
namespace reference
{
template <typename T>
void broadcastdistributed(T* arg, const element::Type element_type, int count)
{
#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
auto data_type = MLSL::DT_FLOAT;
if (element_type == element::f64)
{
data_type = MLSL::DT_DOUBLE;
}
else if (element_type != element::f32)
{
throw std::runtime_error(
"BroadcastDistributed op supports only f32 and f64 types");
}
MLSL::Environment& env = MLSL::Environment::GetEnv();
MLSL::Distribution* distribution = env.CreateDistribution(env.GetProcessCount(), 1);
MLSL::CommReq* req = distribution->Bcast(arg, count, data_type, 0, MLSL::GT_DATA);
env.Wait(req);
env.DeleteDistribution(distribution);
#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
auto data_type = MPI_FLOAT;
if (element_type == element::f64)
void broadcastdistributed(T* arg, const element::Type_t element_type, int count)
{
data_type = MPI_DOUBLE;
}
else if (element_type != element::f32)
{
throw std::runtime_error(
"BroadcastDistributed op supports only f32 and f64 types");
}
MPI_Bcast(arg, count, data_type, 0, MPI_COMM_WORLD);
#else
throw ngraph_error("Distributed Library not supported/mentioned");
#endif
get_distributed_interface()->broadcast(arg, element_type, count);
}
}
}
}
#endif
......@@ -41,13 +41,18 @@ NGRAPH_API const element::Type element::u64(element::Type_t::u64);
class TypeInfo
{
public:
TypeInfo(
size_t bitwidth, bool is_real, bool is_signed, bool is_quantized, const std::string& cname)
TypeInfo(size_t bitwidth,
bool is_real,
bool is_signed,
bool is_quantized,
const std::string& cname,
const std::string& type_name)
: m_bitwidth{bitwidth}
, m_is_real{is_real}
, m_is_signed{is_signed}
, m_is_quantized{is_quantized}
, m_cname{cname}
, m_type_name{type_name}
{
}
size_t m_bitwidth;
......@@ -55,26 +60,28 @@ public:
bool m_is_signed;
bool m_is_quantized;
std::string m_cname;
std::string m_type_name;
};
static const map<element::Type_t, const TypeInfo>& get_type_info_map()
{
static map<element::Type_t, const TypeInfo> s_type_info_map{
{element::Type_t::undefined,
TypeInfo(std::numeric_limits<size_t>::max(), false, false, false, "undefined")},
{element::Type_t::dynamic, TypeInfo(0, false, false, false, "dynamic")},
{element::Type_t::boolean, TypeInfo(8, false, true, false, "char")},
{element::Type_t::bf16, TypeInfo(16, true, true, false, "bfloat16")},
{element::Type_t::f32, TypeInfo(32, true, true, false, "float")},
{element::Type_t::f64, TypeInfo(64, true, true, false, "double")},
{element::Type_t::i8, TypeInfo(8, false, true, true, "int8_t")},
{element::Type_t::i16, TypeInfo(16, false, true, false, "int16_t")},
{element::Type_t::i32, TypeInfo(32, false, true, true, "int32_t")},
{element::Type_t::i64, TypeInfo(64, false, true, false, "int64_t")},
{element::Type_t::u8, TypeInfo(8, false, false, true, "uint8_t")},
{element::Type_t::u16, TypeInfo(16, false, false, false, "uint16_t")},
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t")},
{element::Type_t::u64, TypeInfo(64, false, false, false, "uint64_t")},
TypeInfo(
std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")},
{element::Type_t::dynamic, TypeInfo(0, false, false, false, "dynamic", "dynamic")},
{element::Type_t::boolean, TypeInfo(8, false, true, false, "char", "boolean")},
{element::Type_t::bf16, TypeInfo(16, true, true, false, "bfloat16", "bf16")},
{element::Type_t::f32, TypeInfo(32, true, true, false, "float", "f32")},
{element::Type_t::f64, TypeInfo(64, true, true, false, "double", "f64")},
{element::Type_t::i8, TypeInfo(8, false, true, true, "int8_t", "i8")},
{element::Type_t::i16, TypeInfo(16, false, true, false, "int16_t", "i16")},
{element::Type_t::i32, TypeInfo(32, false, true, true, "int32_t", "i32")},
{element::Type_t::i64, TypeInfo(64, false, true, false, "int64_t", "i64")},
{element::Type_t::u8, TypeInfo(8, false, false, true, "uint8_t", "u8")},
{element::Type_t::u16, TypeInfo(16, false, false, false, "uint16_t", "u16")},
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
{element::Type_t::u64, TypeInfo(64, false, false, false, "uint64_t", "u64")},
};
return s_type_info_map;
};
......@@ -137,6 +144,11 @@ size_t element::Type::hash() const
return static_cast<size_t>(m_type);
}
const std::string& element::Type::get_type_name() const
{
return get_type_info_map().at(m_type).m_type_name;
}
namespace ngraph
{
namespace element
......
......@@ -84,6 +84,8 @@ namespace ngraph
bool is_signed() const;
bool is_quantized() const;
size_t bitwidth() const;
// The name of this type as it appears in the Type_t enum, e.g. "f32"
const std::string& get_type_name() const;
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
bool operator<(const Type& other) const;
......
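The new `get_type_name()` accessor exists so generated code can spell the enum value; a small sketch (assumed usage, not from the commit) contrasting it with the `get_type_enum()` call the builders use:

```cpp
#include <iostream>

#include "ngraph/type/element_type.hpp"

// Illustration only: the two flavours of element-type naming used in this PR.
void show_type_names()
{
    const ngraph::element::Type& t = ngraph::element::f32;

    // "f32" -- what the CPU emitter pastes after "ngraph::element::Type_t::".
    std::cout << t.get_type_name() << "\n";

    // The enum value the builders pass straight to DistributedInterface calls.
    ngraph::element::Type_t e = t.get_type_enum();
    std::cout << std::boolalpha << (e == ngraph::element::Type_t::f32) << "\n";
}
```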
......@@ -44,16 +44,4 @@ if (NGRAPH_GENERIC_CPU_ENABLE)
target_link_libraries(nbench PRIVATE gcpu_backend)
endif()
if(NGRAPH_DISTRIBUTED_ENABLE)
if(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
target_link_libraries(nbench PRIVATE libmlsl)
elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
find_package(MPI REQUIRED)
target_include_directories(nbench SYSTEM PRIVATE ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
target_link_libraries(nbench PRIVATE ${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
else()
message(FATAL_ERROR "Distributed Library not supported/mentioned")
endif()
endif()
install(TARGETS nbench RUNTIME DESTINATION ${NGRAPH_INSTALL_BIN})
......@@ -24,6 +24,7 @@
#include <iomanip>
#include "benchmark.hpp"
#include "ngraph/distributed.hpp"
#include "ngraph/except.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
......@@ -35,10 +36,6 @@
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#if defined NGRAPH_DISTRIBUTED_ENABLE
#include "ngraph/distributed.hpp"
#endif
using namespace std;
using namespace ngraph;
......@@ -290,14 +287,6 @@ OPTIONS
return 1;
}
#if defined NGRAPH_DISTRIBUTED_ENABLE
unique_ptr<ngraph::Distributed> dist(new ngraph::Distributed());
if (dist->get_size() == 1)
{
dist.reset();
}
#endif
vector<string> models;
if (!directory.empty())
{
......@@ -461,12 +450,5 @@ OPTIONS
print_results(aggregate_perf_data, timing_detail);
}
#if defined NGRAPH_DISTRIBUTED_ENABLE
if (dist)
{
dist.reset();
}
#endif
return rc;
}
......@@ -77,10 +77,6 @@ if(NOT WIN32 AND NGRAPH_TOOLS_ENABLE)
list(APPEND SRC tools.cpp)
endif()
if(NGRAPH_DISTRIBUTED_ENABLE)
list(APPEND SRC distributed_setup.cpp)
endif()
set_source_files_properties(includes.cpp PROPERTIES COMPILE_DEFINITIONS
NGRAPH_INCLUDES="${PROJECT_SOURCE_DIR}/src/ngraph")
......@@ -217,19 +213,6 @@ if(NGRAPH_ADDRESS_SANITIZER)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -fsanitize=address -fno-omit-frame-pointer")
endif()
if(NGRAPH_DISTRIBUTED_ENABLE)
if(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
target_link_libraries(unit-test PRIVATE libmlsl)
elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
find_package(MPI REQUIRED)
target_include_directories(unit-test
SYSTEM PRIVATE ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
target_link_libraries(unit-test PRIVATE ${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
else()
message(FATAL_ERROR "Distributed Library not supported/mentioned")
endif()
endif()
target_link_libraries(unit-test PRIVATE ngraph_test_util)
target_link_libraries(unit-test PRIVATE ngraph libgtest)
if (NGRAPH_JSON_ENABLE)
......
......@@ -19,7 +19,6 @@
#include "gtest/gtest.h"
#include "distributed_setup.hpp"
#include "ngraph/distributed.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/ngraph.hpp"
......@@ -32,8 +31,7 @@ using namespace ngraph;
TEST(distributed_${BACKEND_NAME}, allreduce)
{
DistributedSetup distsetup;
auto comm_size = distsetup.get_comm_size();
auto comm_size = get_distributed_interface()->get_size();
if (comm_size > 1)
{
auto shape = Shape{2, 2};
......@@ -69,8 +67,7 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed)
auto result = backend->create_tensor(element::f32, shape);
copy_data(result, vector<float>(4, 0));
DistributedSetup distsetup;
auto processIdx = distsetup.get_comm_rank();
auto processIdx = get_distributed_interface()->get_rank();
if (processIdx == 0)
{
copy_data(result, v);
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "distributed_setup.hpp"
void DistributedSetup::set_comm_size(int comm_size)
{
ngraph_dist_setup::distributed_comm_size = comm_size;
}
void DistributedSetup::set_comm_rank(int comm_rank)
{
ngraph_dist_setup::distributed_comm_rank = comm_rank;
}
int DistributedSetup::get_comm_size()
{
return ngraph_dist_setup::distributed_comm_size;
}
int DistributedSetup::get_comm_rank()
{
return ngraph_dist_setup::distributed_comm_rank;
}
......@@ -22,23 +22,8 @@
using namespace std;
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#include <memory>
#include "ngraph/distributed.hpp"
#include "distributed_setup.hpp"
#endif
int main(int argc, char** argv)
{
#ifdef NGRAPH_DISTRIBUTED_ENABLE
unique_ptr<ngraph::Distributed> dist(new ngraph::Distributed());
DistributedSetup distributed_setup;
distributed_setup.set_comm_size(dist->get_size());
distributed_setup.set_comm_rank(dist->get_rank());
#endif
const char* exclude = "--gtest_filter=-benchmark.*";
vector<char*> argv_vector;
argv_vector.push_back(argv[0]);
......@@ -55,12 +40,6 @@ int main(int argc, char** argv)
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now() - start);
NGRAPH_DEBUG_PRINT("[MAIN] Tests finished: Time: %d ms Exit code: %d", elapsed.count(), rc);
#ifdef NGRAPH_DISTRIBUTED_ENABLE
if (dist)
{
dist.reset();
}
#endif
return rc;
}