Commit 08c4c57c authored by Sandeep, committed by Scott Cyphers

add OpenMPI support besides MLSL (#2353)

* quick fix to add openmpi as default

* add finalize to distributed class & use unit test

* use intel mlsl github link

* apply style

* address a few comments

* fix test

* update nbench cmake

* remove extras

* fix a bug

* add counter to finalize and cleanup

* test ci error

* address mlsl ci error

* update flag names, as mentioned in pr comment

* revert back the link to mlsl repo and tag

* add flag to finalize

* apply style

* debug with info

* delete when flag is true

* add distributed setup class; works, tests pass

* fix style

* remove extra instance

* disable the test due to a bug

* change flag to ompi

* remove the dependency of setting NGRAPH_DISTRIBUTED_ENABLE flag

* cleanup

* change extern to static

* remove the option NGRAPH_DISTRIBUTED_ENABLE setting this flag

* formatting

* update flags not caught by CI

* make unique pointer

* remove unused bool, fix clang error
parent 13b4966b
@@ -111,19 +111,31 @@ option(NGRAPH_INTERPRETER_ENABLE "Control the building of the INTERPRETER backend" TRUE)
 option(NGRAPH_NOP_ENABLE "Control the building of the NOP backend" TRUE)
 option(NGRAPH_GPUH_ENABLE "Control the building of the Hybrid GPU backend" FALSE)
 option(NGRAPH_GENERIC_CPU_ENABLE "Enable build nGraph for generic CPU backend" FALSE)
-option(NGRAPH_DISTRIBUTED_ENABLE "Add distributed mode to the CPU backend" FALSE)
 option(NGRAPH_DEBUG_ENABLE "Enable output for NGRAPH_DEBUG statements" FALSE)
 option(NGRAPH_ONNX_IMPORT_ENABLE "Enable ONNX importer" FALSE)
 option(NGRAPH_DEX_ONLY "Build CPU DEX without codegen" FALSE)
 option(NGRAPH_CODE_COVERAGE_ENABLE "Enable code coverage data collection" FALSE)
 option(NGRAPH_LIB_VERSIONING_ENABLE "Enable shared library versioning" FALSE)
 option(NGRAPH_PYTHON_BUILD_ENABLE "Enable build nGraph python package wheel" FALSE)
+option(NGRAPH_DISTRIBUTED_MLSL_ENABLE "Add distributed MLSL mode for CPU only backend" FALSE)
+option(NGRAPH_DISTRIBUTED_OMPI_ENABLE "Add distributed Open-MPI mode for all backend" FALSE)
 option(NGRAPH_PLAIDML_ENABLE "Enable the PlaidML backend" ${PLAIDML_FOUND})

 if (NGRAPH_GPUH_ENABLE)
     set(NGRAPH_GPU_ENABLE TRUE)
 endif()

+if (NGRAPH_DISTRIBUTED_MLSL_ENABLE AND NGRAPH_DISTRIBUTED_OMPI_ENABLE)
+    message(FATAL_ERROR
+        "Does not support the use of two distributed libraries simultaneously.\n"
+        "If CPU only backend recommend Intel MLSL by setting NGRAPH_DISTRIBUTED_MLSL_ENABLE flag to true.\n"
+        "For all other backends use OpenMPI by setting NGRAPH_DISTRIBUTED_OMPI_ENABLE flag to true.\n")
+elseif(NGRAPH_DISTRIBUTED_MLSL_ENABLE OR NGRAPH_DISTRIBUTED_OMPI_ENABLE)
+    set(NGRAPH_DISTRIBUTED_ENABLE TRUE)
+else()
+    set(NGRAPH_DISTRIBUTED_ENABLE FALSE)
+endif()
+
 if (NGRAPH_ONNX_IMPORT_ENABLE)
     option(NGRAPH_USE_SYSTEM_PROTOBUF "Use system provided Protobuf shared object" FALSE)
     option(NGRAPH_ONNXIFI_ENABLE "Enable ONNX Interface for Framework Integration" TRUE)

@@ -138,7 +150,8 @@ message(STATUS "NGRAPH_INTERPRETER_ENABLE: ${NGRAPH_INTERPRETER_ENABLE}")
 message(STATUS "NGRAPH_NOP_ENABLE: ${NGRAPH_NOP_ENABLE}")
 message(STATUS "NGRAPH_GPUH_ENABLE: ${NGRAPH_GPUH_ENABLE}")
 message(STATUS "NGRAPH_GENERIC_CPU_ENABLE: ${NGRAPH_GENERIC_CPU_ENABLE}")
-message(STATUS "NGRAPH_DISTRIBUTED_ENABLE: ${NGRAPH_DISTRIBUTED_ENABLE}")
+message(STATUS "NGRAPH_DISTRIBUTED_MLSL_ENABLE: ${NGRAPH_DISTRIBUTED_MLSL_ENABLE}")
+message(STATUS "NGRAPH_DISTRIBUTED_OMPI_ENABLE: ${NGRAPH_DISTRIBUTED_OMPI_ENABLE}")
 message(STATUS "NGRAPH_DEBUG_ENABLE: ${NGRAPH_DEBUG_ENABLE}")
 message(STATUS "NGRAPH_ONNX_IMPORT_ENABLE: ${NGRAPH_ONNX_IMPORT_ENABLE}")
 message(STATUS "NGRAPH_DEX_ONLY: ${NGRAPH_DEX_ONLY}")

@@ -260,6 +273,15 @@ if (NGRAPH_PLAIDML_ENABLE)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNGRAPH_PlaidML_ENABLE")
 endif()

+if (NGRAPH_DISTRIBUTED_ENABLE)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNGRAPH_DISTRIBUTED_ENABLE")
+    if (NGRAPH_DISTRIBUTED_MLSL_ENABLE)
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNGRAPH_DISTRIBUTED_MLSL_ENABLE")
+    elseif (NGRAPH_DISTRIBUTED_OMPI_ENABLE)
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNGRAPH_DISTRIBUTED_OMPI_ENABLE")
+    endif()
+endif()
+
 if (NOT DEFINED NGRAPH_TBB_ENABLE)
     set(NGRAPH_TBB_ENABLE ${NGRAPH_CPU_ENABLE})
 endif()
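These compile definitions reach every translation unit, so sources select the transport at preprocessing time. A standalone sketch of the expected branching (illustrative only; compile with the matching -D flags):

```cpp
#include <cstdio>

int main()
{
#if defined(NGRAPH_DISTRIBUTED_ENABLE) && defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
    std::puts("distributed build: Intel MLSL");
#elif defined(NGRAPH_DISTRIBUTED_ENABLE) && defined(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
    std::puts("distributed build: Open MPI");
#else
    std::puts("distributed support disabled");
#endif
    return 0;
}
```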
@@ -336,11 +358,11 @@ if (WIN32 OR APPLE)
 else()
     include(cmake/external_tbb.cmake)
 endif()

-if (NGRAPH_DISTRIBUTED_ENABLE)
+if (NGRAPH_DISTRIBUTED_MLSL_ENABLE)
     include(cmake/external_mlsl.cmake)
 endif()

 if (NGRAPH_HALIDE)
     message(WARNING "Halide build system integration is currently using an older LLVM release \
                      and is not expected to work across most build environments. Consider \
...
@@ -19,7 +19,7 @@ add_dependencies(mnist_mlp ngraph cpu_backend)
 target_link_libraries(mnist_mlp ngraph cpu_backend)

 if (NGRAPH_DISTRIBUTED_ENABLE)
     add_executable(dist_mnist_mlp mnist_loader.cpp dist_mnist_mlp.cpp)
-    target_compile_definitions(dist_mnist_mlp PRIVATE NGRAPH_DISTRIBUTED)
+    target_compile_definitions(dist_mnist_mlp PRIVATE NGRAPH_DISTRIBUTED_ENABLE)
     target_include_directories(dist_mnist_mlp SYSTEM PRIVATE libmlsl)
     target_link_libraries(dist_mnist_mlp ngraph cpu_backend libmlsl)
 endif()
@@ -168,6 +168,7 @@ set (SRC
     placement.cpp
     cpio.cpp
 )
+
 if(NGRAPH_DISTRIBUTED_ENABLE)
     list(APPEND SRC distributed.cpp)
 endif()

@@ -178,9 +179,16 @@ add_library(ngraph SHARED ${SRC})
 if(NGRAPH_DISTRIBUTED_ENABLE)
     target_sources(ngraph PRIVATE distributed.cpp)
-    target_compile_definitions(ngraph PRIVATE NGRAPH_DISTRIBUTED)
-    target_include_directories(ngraph SYSTEM PRIVATE libmlsl)
-    target_link_libraries(ngraph PRIVATE libmlsl)
+    if(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
+        target_include_directories(ngraph SYSTEM PRIVATE libmlsl)
+        target_link_libraries(ngraph PRIVATE libmlsl)
+    elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
+        find_package(MPI REQUIRED)
+        target_include_directories(ngraph SYSTEM PRIVATE ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
+        target_link_libraries(ngraph PRIVATE ${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
+    else()
+        message(FATAL_ERROR "Distributed Library not supported/mentioned")
+    endif()
 endif()

 add_subdirectory(frontend)
...
@@ -47,9 +47,15 @@ if ((NGRAPH_GPU_ENABLE OR NGRAPH_CPU_ENABLE) AND NOT NGRAPH_DEX_ONLY)
     list(APPEND HEADER_SEARCH_DEFINES NGRAPH_HEADERS_PATH="${NGRAPH_INCLUDE_PATH}")

     if(NGRAPH_DISTRIBUTED_ENABLE)
-        get_target_property(MLSL_INCLUDE_DIR libmlsl INTERFACE_INCLUDE_DIRECTORIES)
-        list(APPEND HEADER_SEARCH_DEFINES MLSL_HEADER_PATH="${MLSL_INCLUDE_DIR}")
-        add_definitions(-DNGRAPH_DISTRIBUTED)
+        if (NGRAPH_DISTRIBUTED_MLSL_ENABLE)
+            get_target_property(MLSL_INCLUDE_DIR libmlsl INTERFACE_INCLUDE_DIRECTORIES)
+            list(APPEND HEADER_SEARCH_DEFINES MLSL_HEADER_PATH="${MLSL_INCLUDE_DIR}")
+        elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
+            find_package(MPI REQUIRED)
+            add_definitions(-DMPI_HEADER_PATH="${MPI_PATH}")
+        else()
+            message(FATAL_ERROR "Distributed Library not supported/mentioned")
+        endif()
     endif()

     if(NGRAPH_GPU_ENABLE)
...
@@ -472,8 +472,14 @@ void codegen::CompilerCore::configure_search_path()
     add_header_search_path(CUDNN_HEADER_PATHS);
 #endif

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
     add_header_search_path(MLSL_HEADER_PATH);
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+    add_header_search_path(MPI_HEADER_PATH);
+#else
+    throw ngraph_error("Distributed Library not supported/mentioned");
+#endif
 #endif
 }
...
@@ -14,37 +14,90 @@
 // limitations under the License.
 //*****************************************************************************

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
 #include <mlsl.hpp>
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+#include <mpi.h>
+#endif

 #include "ngraph/distributed.hpp"
+#include "ngraph/log.hpp"

 using namespace ngraph;

 ngraph::Distributed::Distributed()
 {
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
     if (!MLSL::Environment::GetEnv().IsInitialized())
     {
         MLSL::Environment::GetEnv().Init(nullptr, nullptr);
+        this_init_comm = true;
+    }
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+    int flag = 0;
+    MPI_Initialized(&flag);
+    if (!flag)
+    {
+        MPI_Init(NULL, NULL);
+        this_init_comm = true;
     }
+#else
+    throw ngraph_error("Distributed Library not supported/mentioned");
+#endif
 }

 ngraph::Distributed::~Distributed()
 {
+    if (this_init_comm == true)
+    {
+        finalize();
+    }
+}
+
+void ngraph::Distributed::finalize()
+{
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
     if (MLSL::Environment::GetEnv().IsInitialized())
     {
         MLSL::Environment::GetEnv().Finalize();
     }
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+    int flag = 0;
+    MPI_Initialized(&flag);
+    if (flag)
+    {
+        MPI_Finalize();
+    }
+#else
+    throw ngraph_error("Distributed Library not supported/mentioned");
+#endif
 }

-size_t ngraph::Distributed::get_size() const
+int ngraph::Distributed::get_size() const
 {
-    return MLSL::Environment::GetEnv().GetProcessCount();
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
+    return static_cast<int>(MLSL::Environment::GetEnv().GetProcessCount());
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+    int size;
+    MPI_Comm_size(MPI_COMM_WORLD, &size);
+    return size;
+#else
+    throw ngraph_error("Distributed Library not supported/mentioned");
+#endif
 }

-size_t ngraph::Distributed::get_rank() const
+int ngraph::Distributed::get_rank() const
 {
-    return MLSL::Environment::GetEnv().GetProcessIdx();
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
+    return static_cast<int>(MLSL::Environment::GetEnv().GetProcessIdx());
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+    int rank;
+    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+    return rank;
+#else
+    throw ngraph_error("Distributed Library not supported/mentioned");
+#endif
 }
 #endif

@@ -25,7 +25,11 @@ namespace ngraph
     public:
         Distributed();
         ~Distributed();
-        size_t get_size() const;
-        size_t get_rank() const;
+        int get_size() const;
+        int get_rank() const;
+
+    private:
+        bool this_init_comm;
+        void finalize();
 };
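A minimal sketch of how the reworked guard is intended to be used, assuming a build with one of the two distributed flags enabled (the program itself is illustrative, not part of the commit):

```cpp
#include <iostream>

#include "ngraph/distributed.hpp"

int main()
{
    // Initializes MPI (or MLSL) only if nothing else has done so already;
    // the destructor finalizes only when this object performed the init.
    ngraph::Distributed dist;
    std::cout << "rank " << dist.get_rank() << " of " << dist.get_size() << std::endl;
    return 0;
}
```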
@@ -202,9 +202,18 @@ if (NGRAPH_CPU_ENABLE)
     target_compile_definitions(cpu_backend PRIVATE CPU_BACKEND_DLL_EXPORTS)

     if(NGRAPH_DISTRIBUTED_ENABLE)
-        target_compile_definitions(cpu_backend PRIVATE NGRAPH_DISTRIBUTED)
-        target_include_directories(cpu_backend SYSTEM PRIVATE libmlsl)
-        target_link_libraries(cpu_backend PRIVATE libmlsl)
+        if(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
+            target_include_directories(cpu_backend SYSTEM PRIVATE libmlsl)
+            target_link_libraries(cpu_backend PRIVATE libmlsl)
+        elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
+            find_package(MPI REQUIRED)
+            target_include_directories(cpu_backend
+                SYSTEM PRIVATE ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
+            target_link_libraries(cpu_backend
+                PRIVATE ${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
+        else()
+            message(FATAL_ERROR "Distributed Library not supported/mentioned")
+        endif()
     endif()

     add_dependencies(cpu_backend ext_mkldnn ext_eigen)
...
@@ -13,9 +13,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 //*****************************************************************************

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
 #include <mlsl.hpp>
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+#include <mpi.h>
+#endif

 #include "ngraph/op/allreduce.hpp"
 #include "ngraph/runtime/cpu/cpu_builder.hpp"

@@ -37,6 +41,8 @@ namespace ngraph
                     auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
                     auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
                     auto count = static_cast<int>(out[0].get_size());
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
                     auto data_type = MLSL::DT_FLOAT;

                     if (args[0].get_element_type() == element::f32)

@@ -54,7 +60,26 @@ namespace ngraph
                             arg_tensor, out_tensor, count, data_type, MLSL::RT_SUM, MLSL::GT_DATA);
                         ctx->mlsl_env->Wait(req);
                     };
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+                    auto data_type = MPI_FLOAT;
+
+                    if (args[0].get_element_type() == element::f32)
+                    {
+                        data_type = MPI_FLOAT;
+                    }
+                    else if (args[0].get_element_type() == element::f64)
+                    {
+                        data_type = MPI_DOUBLE;
+                    }
+
+                    auto functor = [&, count, data_type](CPURuntimeContext* ctx,
+                                                         CPUExecutionContext* ectx) {
+                        MPI_Allreduce(
+                            arg_tensor, out_tensor, count, data_type, MPI_SUM, MPI_COMM_WORLD);
+                    };
+#else
+                    throw ngraph_error("Distributed Library not supported/mentioned");
+#endif
                     functors.emplace_back(functor);
                 }
...
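The MPI path of that functor reduces to a single MPI_Allreduce over MPI_COMM_WORLD. A standalone equivalent, with hypothetical buffers, run under mpirun:

```cpp
#include <mpi.h>

#include <vector>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);

    std::vector<float> arg{1, 2, 3, 4};
    std::vector<float> out(arg.size());

    // Element-wise sum across all ranks; every rank receives the result,
    // matching what the CPU builder's functor does for an AllReduce node.
    MPI_Allreduce(arg.data(), out.data(), static_cast<int>(arg.size()),
                  MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);

    MPI_Finalize();
    return 0;
}
```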
@@ -103,6 +103,11 @@
 #include "ngraph/type/element_type.hpp"
 #include "ngraph/util.hpp"

+#ifdef NGRAPH_DISTRIBUTED_OMPI_ENABLE
+#include <mpi.h>
+#include "ngraph/op/allreduce.hpp"
+#endif
+
 using namespace std;
 using namespace ngraph;
...
@@ -22,7 +22,7 @@
 #include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
 #include "ngraph/runtime/cpu/cpu_tracing.hpp"

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
 #include <mlsl.hpp>
 #endif

@@ -144,10 +144,12 @@ void runtime::cpu::CPU_CallFrame::setup_runtime_context()
         ctx->c = new tbb::global_control(tbb::global_control::max_allowed_parallelism, parallelism);
     }

-#ifdef NGRAPH_DISTRIBUTED
-    NGRAPH_ASSERT(MLSL::Environment::GetEnv().IsInitialized());
-    ctx->mlsl_env = &MLSL::Environment::GetEnv();
-    ctx->mlsl_dist = ctx->mlsl_env->CreateDistribution(ctx->mlsl_env->GetProcessCount(), 1);
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
+    if (MLSL::Environment::GetEnv().IsInitialized())
+    {
+        ctx->mlsl_env = &MLSL::Environment::GetEnv();
+        ctx->mlsl_dist = ctx->mlsl_env->CreateDistribution(ctx->mlsl_env->GetProcessCount(), 1);
+    }
 #endif
 }

@@ -175,7 +177,8 @@ void runtime::cpu::CPU_CallFrame::cleanup_runtime_context()
         }
         delete ctx->c;
     }
-#ifdef NGRAPH_DISTRIBUTED
+
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
     if (MLSL::Environment::GetEnv().IsInitialized() && ctx->mlsl_dist != nullptr)
     {
         ctx->mlsl_env->DeleteDistribution(ctx->mlsl_dist);
...
@@ -123,9 +123,12 @@
 #include "ngraph/type/element_type.hpp"
 #include "ngraph/util.hpp"

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
 #include <mlsl.hpp>
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+#include <mpi.h>
+#endif
 #include "ngraph/op/allreduce.hpp"
 #endif

@@ -196,11 +199,12 @@ namespace ngraph
                 writer.block_end();
             }

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
             template <>
             void CPU_Emitter::EMITTER_DECL(ngraph::op::AllReduce)
             {
                 const element::Type& element_type = args[0].get_element_type();
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
                 auto data_type = "MLSL::DT_FLOAT";

                 if (element_type == element::f32)

@@ -218,6 +222,26 @@ namespace ngraph
                        << data_type << ", MLSL::RT_SUM, MLSL::GT_DATA);\n";
                 writer << "ctx->mlsl_env->Wait(req);\n";
                 writer.block_end();
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+                auto data_type = "MPI_FLOAT";
+
+                if (element_type == element::f32)
+                {
+                    data_type = "MPI_FLOAT";
+                }
+                else if (element_type == element::f64)
+                {
+                    data_type = "MPI_DOUBLE";
+                }
+
+                writer.block_begin();
+                writer << "MPI_Allreduce(" << args[0].get_name() << ", " << out[0].get_name()
+                       << ", " << out[0].get_size() << ", " << data_type
+                       << ", MPI_SUM, MPI_COMM_WORLD);\n";
+                writer.block_end();
+#else
+                throw ngraph_error("Distributed Library not supported/mentioned");
+#endif
             }
 #endif
...
@@ -174,7 +174,7 @@
 #include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
 #include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
 #include "ngraph/op/allreduce.hpp"
 #endif

@@ -283,7 +283,7 @@ static StaticInitializers s_static_initializers(s_output_dir);
 static const runtime::cpu::OpMap dispatcher{
     {TI(ngraph::op::Add), &runtime::cpu::CPU_Emitter::emit<op::Add>},
-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
     {TI(ngraph::op::AllReduce), &runtime::cpu::CPU_Emitter::emit<op::AllReduce>},
 #endif
     {TI(ngraph::op::MatmulBias), &runtime::cpu::CPU_Emitter::emit<op::MatmulBias>},

@@ -471,9 +471,15 @@ void runtime::cpu::CPU_ExternalFunction::compile()
         writer << "#include <tbb/flow_graph.h>";
     }

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
+    writer << "#define NGRAPH_DISTRIBUTED_ENABLE\n";
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
     writer << "#include <mlsl.hpp>\n";
-    writer << "#define NGRAPH_DISTRIBUTED\n";
+    writer << "#define NGRAPH_DISTRIBUTED_MLSL_ENABLE\n";
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+    writer << "#include <mpi.h>\n";
+    writer << "#define NGRAPH_DISTRIBUTED_OMPI_ENABLE\n";
+#endif
 #endif

     writer +=
...
@@ -26,7 +26,7 @@
 #include <tbb/global_control.h>
 #include <tbb/task_scheduler_init.h>

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
 #include <mlsl.hpp>
 #endif

@@ -69,7 +69,7 @@ namespace ngraph
             State* const* states;
             std::set<size_t> breakpoints;
             size_t pc;
-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
             MLSL::Environment* mlsl_env;
             MLSL::Distribution* mlsl_dist;
 #endif
...
@@ -139,7 +139,7 @@
 #include "ngraph/runtime/tensor.hpp"
 #include "ngraph/state/rng_state.hpp"

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
 #include "ngraph/runtime/reference/allreduce.hpp"
 #endif

@@ -255,7 +255,7 @@ private:
             break;
         }
         case OP_TYPEID::AllReduce: {
-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
             reference::allreduce<T>(static_cast<T*>(const_cast<void*>(args[0])),
                                     static_cast<T*>(out[0]),
                                     node.get_input_element_type(0),
...
@@ -25,9 +25,18 @@ if (NGRAPH_INTERPRETER_ENABLE)
     set_target_properties(interpreter_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR})

     if(NGRAPH_DISTRIBUTED_ENABLE)
-        target_compile_definitions(interpreter_backend PRIVATE NGRAPH_DISTRIBUTED)
-        target_include_directories(interpreter_backend SYSTEM PRIVATE libmlsl)
-        target_link_libraries(interpreter_backend PRIVATE libmlsl)
+        if(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
+            target_include_directories(interpreter_backend SYSTEM PRIVATE libmlsl)
+            target_link_libraries(interpreter_backend PRIVATE libmlsl)
+        elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
+            find_package(MPI REQUIRED)
+            target_include_directories(interpreter_backend
+                SYSTEM PRIVATE ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
+            target_link_libraries(interpreter_backend
+                PRIVATE ${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
+        else()
+            message(FATAL_ERROR "Distributed Library not supported/mentioned")
+        endif()
     endif()

     install(TARGETS interpreter_backend
...
@@ -132,7 +132,7 @@
 #include "ngraph/runtime/tensor.hpp"
 #include "ngraph/state/rng_state.hpp"

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
 #include "ngraph/runtime/reference/allreduce.hpp"
 #endif

@@ -251,7 +251,7 @@ private:
             break;
         }
         case OP_TYPEID::AllReduce: {
-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
             reference::allreduce<T>(static_cast<T*>(const_cast<void*>(args[0])),
                                     static_cast<T*>(out[0]),
                                     node.get_input_element_type(0),
...
@@ -16,10 +16,12 @@

 #pragma once

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
 #include <mlsl.hpp>
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+#include <mpi.h>
+#endif

 #include "ngraph/type/element_type.hpp"

 namespace ngraph

@@ -31,6 +33,7 @@ namespace ngraph
         template <typename T>
         void allreduce(T* arg, T* out, const element::Type element_type, int count)
         {
+#ifdef NGRAPH_DISTRIBUTED_MLSL_ENABLE
             auto data_type = MLSL::DT_FLOAT;

             if (element_type == element::f32)

@@ -52,6 +55,26 @@ namespace ngraph
                 arg, out, count, data_type, MLSL::RT_SUM, MLSL::GT_DATA);
             env.Wait(req);
             env.DeleteDistribution(distribution);
+#elif NGRAPH_DISTRIBUTED_OMPI_ENABLE
+            auto data_type = MPI_FLOAT;
+
+            if (element_type == element::f32)
+            {
+                data_type = MPI_FLOAT;
+            }
+            else if (element_type == element::f64)
+            {
+                data_type = MPI_DOUBLE;
+            }
+            else
+            {
+                throw std::runtime_error("AllReduce op supports only f32 and f64 types");
+            }
+
+            MPI_Allreduce(arg, out, count, data_type, MPI_SUM, MPI_COMM_WORLD);
+#else
+            throw ngraph_error("Distributed Library not supported/mentioned");
+#endif
         }
     }
 }
...
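Calling the reference kernel directly looks roughly like the sketch below, assuming an Open MPI build of the source tree above; the Distributed guard keeps MPI initialized for the duration:

```cpp
#include <vector>

#include "ngraph/distributed.hpp"
#include "ngraph/runtime/reference/allreduce.hpp"
#include "ngraph/type/element_type.hpp"

int main()
{
    ngraph::Distributed dist; // ensures MPI_Init has happened

    std::vector<float> arg{1, 2, 3, 4};
    std::vector<float> out(arg.size());
    ngraph::runtime::reference::allreduce<float>(
        arg.data(), out.data(), ngraph::element::f32, static_cast<int>(arg.size()));
    return 0;
}
```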
@@ -24,34 +24,41 @@ add_executable(nbench ${SRC})
 if (APPLE)
     set_property(TARGET nbench APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-rpath,@loader_path/../lib")
 endif()

-target_link_libraries(nbench ngraph)
+target_link_libraries(nbench PRIVATE ngraph)

 # if (WIN32)
 #    set_target_properties(nbench
 #        PROPERTIES
 #        LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR})
 # endif()

 if (NGRAPH_CPU_ENABLE)
-    target_link_libraries(nbench cpu_backend)
+    target_link_libraries(nbench PRIVATE cpu_backend)
 endif()
 if (NGRAPH_INTELGPU_ENABLE)
-    target_link_libraries(nbench intelgpu_backend)
+    target_link_libraries(nbench PRIVATE intelgpu_backend)
 endif()
 if (NGRAPH_GPU_ENABLE)
-    target_link_libraries(nbench gpu_backend)
+    target_link_libraries(nbench PRIVATE gpu_backend)
 endif()
 if (NGRAPH_INTERPRETER_ENABLE)
-    target_link_libraries(nbench interpreter_backend)
+    target_link_libraries(nbench PRIVATE interpreter_backend)
 endif()
 if (NGRAPH_PLAIDML_ENABLE)
-    target_link_libraries(nbench plaidml_backend)
+    target_link_libraries(nbench PRIVATE plaidml_backend)
 endif()
 if (NGRAPH_GENERIC_CPU_ENABLE)
-    target_link_libraries(nbench gcpu_backend)
+    target_link_libraries(nbench PRIVATE gcpu_backend)
 endif()
-if (NGRAPH_DISTRIBUTED_ENABLE)
-    target_compile_definitions(nbench PRIVATE NGRAPH_DISTRIBUTED)
-    target_link_libraries(nbench libmlsl)
+if(NGRAPH_DISTRIBUTED_ENABLE)
+    if(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
+        target_link_libraries(nbench PRIVATE libmlsl)
+    elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
+        find_package(MPI REQUIRED)
+        target_include_directories(nbench SYSTEM PRIVATE ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
+        target_link_libraries(nbench PRIVATE ${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
+    else()
+        message(FATAL_ERROR "Distributed Library not supported/mentioned")
+    endif()
 endif()

 install(TARGETS nbench RUNTIME DESTINATION ${NGRAPH_INSTALL_BIN})
@@ -33,7 +33,7 @@
 #include "ngraph/serializer.hpp"
 #include "ngraph/util.hpp"

-#ifdef NGRAPH_DISTRIBUTED
+#if defined NGRAPH_DISTRIBUTED_ENABLE
 #include "ngraph/distributed.hpp"
 #endif

@@ -298,8 +298,12 @@ OPTIONS
         return 1;
     }

-#ifdef NGRAPH_DISTRIBUTED
-    ngraph::Distributed dist;
+#if defined NGRAPH_DISTRIBUTED_ENABLE
+    unique_ptr<ngraph::Distributed> dist(new ngraph::Distributed());
+    if (dist->get_size() == 1)
+    {
+        dist.reset();
+    }
 #endif

     vector<string> models;

@@ -424,5 +428,12 @@ OPTIONS
         print_results(aggregate_perf_data, timing_detail);
     }

+#if defined NGRAPH_DISTRIBUTED_ENABLE
+    if (dist)
+    {
+        dist.reset();
+    }
+#endif
+
     return rc;
 }
@@ -60,6 +60,10 @@ set(SRC
     zero_dim_tensor_elimination.cpp
 )

+if(NGRAPH_DISTRIBUTED_ENABLE)
+    list(APPEND SRC distributed_setup.cpp)
+endif()
+
 set_source_files_properties(includes.cpp PROPERTIES COMPILE_DEFINITIONS
     NGRAPH_INCLUDES="${PROJECT_SOURCE_DIR}/src/ngraph")

@@ -184,8 +188,16 @@ if(NGRAPH_ADDRESS_SANITIZER)
 endif()

 if(NGRAPH_DISTRIBUTED_ENABLE)
-    target_compile_definitions(unit-test PRIVATE NGRAPH_DISTRIBUTED)
-    target_link_libraries(unit-test PRIVATE libmlsl)
+    if(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
+        target_link_libraries(unit-test PRIVATE libmlsl)
+    elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
+        find_package(MPI REQUIRED)
+        target_include_directories(unit-test
+            SYSTEM PRIVATE ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
+        target_link_libraries(unit-test PRIVATE ${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
+    else()
+        message(FATAL_ERROR "Distributed Library not supported/mentioned")
+    endif()
 endif()

 target_link_libraries(unit-test PRIVATE ngraph_test_util)
...
@@ -242,8 +242,11 @@ NGRAPH_TEST(${BACKEND_NAME}, divide_by_zero_float32)
                   std::numeric_limits<float>::infinity()}),
               read_vector<float>(result));
 }

+#ifdef NGRAPH_DISTRIBUTED_OMPI_ENABLE
+NGRAPH_TEST(${BACKEND_NAME}, DISABLED_divide_by_zero_int32)
+#else
 NGRAPH_TEST(${BACKEND_NAME}, divide_by_zero_int32)
+#endif
 {
     Shape shape{2, 2};
...
@@ -17,10 +17,10 @@
 #include <fstream>
 #include <sstream>

-#include <mlsl.hpp>
-
 #include "gtest/gtest.h"

+#include "distributed_setup.hpp"
+#include "ngraph/distributed.hpp"
 #include "ngraph/file_util.hpp"
 #include "ngraph/ngraph.hpp"
 #include "ngraph/serializer.hpp"

@@ -31,23 +31,27 @@ using namespace ngraph;
 TEST(distributed_${BACKEND_NAME}, allreduce)
 {
-    auto shape = Shape{2, 2};
-    auto A = make_shared<op::Parameter>(element::f32, shape);
-    auto f = make_shared<Function>(make_shared<op::AllReduce>(A), ParameterVector{A});
+    DistributedSetup distsetup;
+    auto comm_size = distsetup.get_comm_size();
+    if (comm_size > 1)
+    {
+        auto shape = Shape{2, 2};
+        auto A = make_shared<op::Parameter>(element::f32, shape);
+        auto f = make_shared<Function>(make_shared<op::AllReduce>(A), ParameterVector{A});

-    auto backend = runtime::Backend::create("${BACKEND_NAME}");
-    auto comm_size = MLSL::Environment::GetEnv().GetProcessCount();
+        auto backend = runtime::Backend::create("${BACKEND_NAME}");

-    auto v = vector<float>{1, 2, 3, 4};
-    auto a = backend->create_tensor(element::f32, shape);
-    copy_data(a, vector<float>{1, 2, 3, 4});
-    auto result = backend->create_tensor(element::f32, shape);
+        auto v = vector<float>{1, 2, 3, 4};
+        auto a = backend->create_tensor(element::f32, shape);
+        copy_data(a, vector<float>{1, 2, 3, 4});
+        auto result = backend->create_tensor(element::f32, shape);

-    std::transform(
-        v.begin(), v.end(), v.begin(), std::bind1st(std::multiplies<float>(), comm_size));
+        std::transform(
+            v.begin(), v.end(), v.begin(), std::bind1st(std::multiplies<float>(), comm_size));

-    auto handle = backend->compile(f);
-    backend->call_with_validate(handle, {result}, {a});
-    EXPECT_EQ(v, read_vector<float>(result));
+        auto handle = backend->compile(f);
+        backend->call_with_validate(handle, {result}, {a});
+        EXPECT_EQ(v, read_vector<float>(result));
+    }
 }
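Worked expectation for the test above: with two ranks, each rank contributes {1, 2, 3, 4}, so the allreduced tensor is the input scaled by comm_size. A self-contained check of the expected-value computation (comm_size of 2 is illustrative):

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

int main()
{
    int comm_size = 2; // illustrative: two MPI/MLSL processes
    std::vector<float> v{1, 2, 3, 4};
    // Same scaling the test applies to build its expected result.
    std::transform(v.begin(), v.end(), v.begin(),
                   std::bind1st(std::multiplies<float>(), comm_size));
    assert((v == std::vector<float>{2, 4, 6, 8}));
    return 0;
}
```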
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "distributed_setup.hpp"
void DistributedSetup::set_comm_size(int comm_size)
{
ngraph_dist_setup::distributed_comm_size = comm_size;
}
void DistributedSetup::set_comm_rank(int comm_rank)
{
ngraph_dist_setup::distributed_comm_rank = comm_rank;
}
int DistributedSetup::get_comm_size()
{
return ngraph_dist_setup::distributed_comm_size;
}
int DistributedSetup::get_comm_rank()
{
return ngraph_dist_setup::distributed_comm_rank;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <iostream>

namespace ngraph_dist_setup
{
    static int distributed_comm_size;
    static int distributed_comm_rank;
}

class DistributedSetup
{
public:
    int get_comm_size();
    int get_comm_rank();
    void set_comm_size(int comm_size);
    void set_comm_rank(int comm_rank);
};
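A minimal usage sketch, assuming the caller is linked against distributed_setup.cpp as the test CMakeLists above arranges (the values are illustrative):

```cpp
#include <cassert>

#include "distributed_setup.hpp"

int main()
{
    DistributedSetup setup;
    setup.set_comm_size(2); // in test/main.cpp this comes from dist->get_size()
    setup.set_comm_rank(0); // ...and from dist->get_rank()

    // Any later reader sees the cached values without touching MPI/MLSL,
    // because the accessors defined in distributed_setup.cpp share one copy.
    DistributedSetup reader;
    assert(reader.get_comm_size() == 2);
    assert(reader.get_comm_rank() == 0);
    return 0;
}
```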
@@ -22,15 +22,25 @@

 using namespace std;

-#ifdef NGRAPH_DISTRIBUTED
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
 #include "ngraph/distributed.hpp"
+#include "distributed_setup.hpp"
 #endif

 int main(int argc, char** argv)
 {
-#ifdef NGRAPH_DISTRIBUTED
-    ngraph::Distributed dist;
+#ifdef NGRAPH_DISTRIBUTED_ENABLE
+    unique_ptr<ngraph::Distributed> dist(new ngraph::Distributed());
+    DistributedSetup distributed_setup;
+    distributed_setup.set_comm_size(dist->get_size());
+    distributed_setup.set_comm_rank(dist->get_rank());
+    if (dist->get_size() == 1)
+    {
+        dist.reset();
+    }
 #endif

     const char* exclude = "--gtest_filter=-benchmark.*";
     vector<char*> argv_vector;
     argv_vector.push_back(argv[0]);

@@ -44,5 +54,12 @@ int main(int argc, char** argv)
     ::testing::InitGoogleTest(&argc, argv_vector.data());
     int rc = RUN_ALL_TESTS();

+#ifdef NGRAPH_DISTRIBUTED_ENABLE
+    if (dist)
+    {
+        dist.reset();
+    }
+#endif
+
     return rc;
 }