Unverified Commit b9c5b9d3 authored by Sevin F. Varoglu's avatar Sevin F. Varoglu Committed by GitHub

add AllReduce op and MPI support (#425)

- enable distributed ngraph (MPI)
- add AllReduce op to ngraph core, interpreter and CPU backend
- add AllReduce unit test
parent 62342c4e
......@@ -133,6 +133,16 @@ elseif(NGRAPH_GPU_ENABLE)
message(FATAL_ERROR "GPU was required but CUDA library was not found")
endif()
#-----------------------------------------------------------------------------------------------
# distributed support
#-----------------------------------------------------------------------------------------------
if(NGRAPH_DISTRIBUTED_ENABLE)
find_package(MPI REQUIRED)
if(MPI_CXX_FOUND)
add_definitions(-DNGRAPH_DISTRIBUTED)
endif()
endif()
#-----------------------------------------------------------------------------------------------
# External projects install directory
#-----------------------------------------------------------------------------------------------
......
......@@ -202,6 +202,19 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
set_source_files_properties(codegen/compiler.cpp PROPERTIES COMPILE_DEFINITIONS "${HEADER_SEARCH_DEFINES}")
set(NGRAPH_CPU_DEBUGINFO_ENABLE 0 CACHE STRING "Enable debuginfo in the CPU backend")
if(NGRAPH_DISTRIBUTED_ENABLE AND MPI_CXX_INCLUDE_PATH)
include_directories(SYSTEM ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
link_directories(${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
# Add sources for distributed ngraph
# and all its dependencies
set(SRC ${SRC}
ops/allreduce.cpp
)
set_property(SOURCE codegen/compiler.cpp APPEND PROPERTY COMPILE_DEFINITIONS
"MPI_HEADER_PATH=\"${MPI_C_INCLUDE_PATH}\";")
endif()
# GPU backend current requires CPU because they share compiler.cpp,
# and compiler.cpp requires MKLDNN
if(NGRAPH_GPU_ENABLE)
......@@ -291,6 +304,10 @@ if(NGRAPH_GPU_ENABLE AND CUDA_LIBRARIES)
target_link_libraries(ngraph PRIVATE ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDNN_LIBRARIES})
endif()
if(MPI_CXX_INCLUDE_PATH)
target_link_libraries(ngraph PRIVATE ${MPI_CXX_LIBRARIES})
endif()
# Argon
if (NGRAPH_ARGON_ENABLE)
target_link_libraries(ngraph PRIVATE ${ARGON_TRANSFORMER_LIB_DIR}/libargon.so)
......
......@@ -400,6 +400,10 @@ void codegen::StaticCompiler::configure_search_path()
// Only needed for GPU backend
add_header_search_path(CUDA_HEADER_PATHS);
#endif
#ifdef NGRAPH_DISTRIBUTED
add_header_search_path(MPI_HEADER_PATH);
#endif
}
void codegen::StaticCompiler::load_headers_from_resource()
......
......@@ -127,3 +127,7 @@
#include "ngraph/shape.hpp"
#include "ngraph/types/element_type.hpp"
#include "ngraph/types/type.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/ops/allreduce.hpp"
#endif
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/ops/allreduce.hpp"
using namespace std;
using namespace ngraph;
op::AllReduce::AllReduce(const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs("AllReduce", {arg})
{
auto& input = m_inputs.at(0);
set_value_type_checked(
make_shared<TensorViewType>(input.get_element_type(), input.get_shape()));
if ((arg->get_element_type() != element::f32) && (arg->get_element_type() != element::f64))
{
throw ngraph_error("Unsupported data type for AllReduce");
}
}
#endif
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#ifdef NGRAPH_DISTRIBUTED
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
{
namespace op
{
class AllReduce : public RequiresTensorViewArgs
{
public:
AllReduce(const std::shared_ptr<Node>& arg);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<AllReduce>(new_args.at(0));
}
};
}
}
#endif
......@@ -45,8 +45,14 @@
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include "ngraph/runtime/cpu/cpu_kernel_emitters.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/types/element_type.hpp"
#include "ngraph/util.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include <mpi.h>
#include "ngraph/ops/allreduce.hpp"
#endif
using namespace std;
using namespace ngraph;
......@@ -127,6 +133,30 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitAdd)
writer << "}\n";
}
#ifdef NGRAPH_DISTRIBUTED
void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitAllReduce)
{
const element::Type& element_type = args[0].get_element_type();
auto data_type = "MPI_FLOAT";
if (element_type == element::f32)
{
data_type = "MPI_FLOAT";
}
else if (element_type == element::f64)
{
data_type = "MPI_DOUBLE";
}
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer << "MPI_Allreduce(" << args[0].get_name() << ", " << out[0].get_name() << ", "
<< out[0].get_size() << ", " << data_type << ", MPI_SUM, MPI_COMM_WORLD);\n";
writer.indent--;
writer << "}\n";
}
#endif
//TODO: This could be further optimized to reduce the impact of memcpy by either
//a) emitting customized code for initializing output/bias
//b) emitting two cblas calls (one for gemm on W and x and the second for gemm on Bias and E^T + the result of the first gemm)
......
......@@ -42,6 +42,9 @@ namespace ngraph
public:
static void EMITTER_DECL(EmitNop);
static void EMITTER_DECL(EmitAdd);
#ifdef NGRAPH_DISTRIBUTED
static void EMITTER_DECL(EmitAllReduce);
#endif
static void EMITTER_DECL(EmitDot);
static void EMITTER_DECL(EmitMultiply);
static void EMITTER_DECL(EmitGetOutputElement);
......
......@@ -99,6 +99,10 @@
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_layout.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/ops/allreduce.hpp"
#endif
using namespace std;
using namespace ngraph;
......@@ -151,6 +155,9 @@ static StaticInitializers s_static_initializers;
static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::CPU_Emitter::EmitAdd},
#ifdef NGRAPH_DISTRIBUTED
{TI(ngraph::op::AllReduce), &runtime::cpu::CPU_Emitter::EmitAllReduce},
#endif
{TI(ngraph::op::MatmulBias), &runtime::cpu::CPU_Emitter::EmitMatmulBias},
{TI(ngraph::op::Dot), &runtime::cpu::CPU_Emitter::EmitDot},
{TI(ngraph::op::Multiply), &runtime::cpu::CPU_Emitter::EmitMultiply},
......@@ -290,6 +297,10 @@ using namespace ngraph::runtime;
)";
#ifdef NGRAPH_DISTRIBUTED
writer << "#include <mpi.h>\n\n";
#endif
if (m_use_tbb)
{
writer << "#include <tbb/flow_graph.h>\n";
......
......@@ -96,6 +96,10 @@
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/util.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/runtime/kernel/allreduce.hpp"
#endif
namespace ngraph
{
namespace runtime
......@@ -235,6 +239,15 @@ private:
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
#ifdef NGRAPH_DISTRIBUTED
else if (node_op == "AllReduce")
{
kernel::allreduce<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_element_type(),
static_cast<int>(args[0]->get_element_count()));
}
#endif
else if (node_op == "Asin")
{
kernel::asin<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#ifdef NGRAPH_DISTRIBUTED
#include <mpi.h>
#include "ngraph/types/element_type.hpp"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void allreduce(T* arg, T* out, const element::Type element_type, int count)
{
auto data_type = MPI_FLOAT;
if (element_type == element::f32)
{
data_type = MPI_FLOAT;
}
else if (element_type == element::f64)
{
data_type = MPI_DOUBLE;
}
MPI_Allreduce(arg, out, count, data_type, MPI_SUM, MPI_COMM_WORLD);
}
}
}
}
#endif
......@@ -71,6 +71,10 @@
#include "ngraph/ops/tanh.hpp"
#include "ngraph/util.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/ops/allreduce.hpp"
#endif
using namespace ngraph;
using namespace std;
using json = nlohmann::json;
......@@ -324,6 +328,12 @@ static shared_ptr<ngraph::Function>
{
node = make_shared<op::Add>(args[0], args[1]);
}
#ifdef NGRAPH_DISTRIBUTED
else if (node_op == "AllReduce")
{
node = make_shared<op::AllReduce>(args[0]);
}
#endif
else if (node_op == "Asin")
{
node = make_shared<op::Asin>(args[0]);
......@@ -804,6 +814,9 @@ static json write(const Node& n)
else if (node_op == "Add")
{
}
else if (node_op == "AllReduce")
{
}
else if (node_op == "Asin")
{
}
......
......@@ -111,6 +111,15 @@ if(NGRAPH_ARGON_ENABLE)
set(SRC ${SRC} ${ADDITIONAL_ARGON_TEST})
endif()
if(NGRAPH_DISTRIBUTED_ENABLE AND MPI_C_INCLUDE_PATH)
include_directories(SYSTEM ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
link_directories(${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
foreach(BACKEND_NAME ${BACKEND_NAMES})
configure_file(distributed.cpp distributed_${BACKEND_NAME}.cpp)
set(SRC ${SRC} ${CMAKE_CURRENT_BINARY_DIR}/distributed_${BACKEND_NAME}.cpp)
endforeach()
endif()
foreach(BACKEND_NAME ${BACKEND_NAMES})
configure_file(backend_test.in.cpp backend_test_${BACKEND_NAME}.cpp)
configure_file(convolution_test.in.cpp convolution_test_${BACKEND_NAME}.cpp)
......@@ -144,6 +153,10 @@ endif()
add_executable(unit-test ${SRC})
if(MPI_C_INCLUDE_PATH)
target_link_libraries(unit-test ${MPI_CXX_LIBRARIES})
endif()
if(MKLDNN_INCLUDE_DIR)
target_link_libraries(unit-test mkldnn)
add_dependencies(unit-test ext_mkldnn)
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <fstream>
#include <sstream>
#include <mpi.h>
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"
#include "util/random.hpp"
using namespace std;
using namespace ngraph;
TEST(distributed_${BACKEND_NAME}, allreduce)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::AllReduce>(A), op::Parameters{A});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
auto v = vector<float>{1, 2, 3, 4};
int comm_size;
MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
auto a = backend->make_primary_tensor_view(element::f32, shape);
copy_data(a, vector<float>{1, 2, 3, 4});
auto result = backend->make_primary_tensor_view(element::f32, shape);
std::transform(
v.begin(), v.end(), v.begin(), std::bind1st(std::multiplies<float>(), comm_size));
cf->call({a}, {result});
EXPECT_EQ(v, read_vector<float>(result));
}
......@@ -22,6 +22,19 @@
using namespace std;
#ifdef NGRAPH_DISTRIBUTED
#include <mpi.h>
class MpiEnvironment : public ::testing::Environment
{
protected:
virtual void SetUp() { MPI::Init(); }
virtual void TearDown() { MPI::Finalize(); }
virtual ~MpiEnvironment() {}
};
#endif
int main(int argc, char** argv)
{
const char* exclude = "--gtest_filter=-benchmark.*";
......@@ -35,6 +48,9 @@ int main(int argc, char** argv)
argc++;
::testing::InitGoogleTest(&argc, argv_vector.data());
#ifdef NGRAPH_DISTRIBUTED
::testing::AddGlobalTestEnvironment(new MpiEnvironment);
#endif
int rc = RUN_ALL_TESTS();
return rc;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment