Commit 0fa17649 authored by Sevin F. Varoglu's avatar Sevin F. Varoglu Committed by Robert Kimball

refactor distributed code (#1619)

* refactor distributed code

* add test file changes

* add distributed class

* update test file
parent 94f901ed
...@@ -18,11 +18,6 @@ add_executable(mnist_mlp mnist_loader.cpp mnist_mlp.cpp) ...@@ -18,11 +18,6 @@ add_executable(mnist_mlp mnist_loader.cpp mnist_mlp.cpp)
add_dependencies(mnist_mlp ngraph cpu_backend) add_dependencies(mnist_mlp ngraph cpu_backend)
target_link_libraries(mnist_mlp ngraph cpu_backend) target_link_libraries(mnist_mlp ngraph cpu_backend)
if (NGRAPH_DISTRIBUTED_ENABLE) if (NGRAPH_DISTRIBUTED_ENABLE)
find_package(MPI REQUIRED)
add_definitions(-DNGRAPH_DISTRIBUTED)
include_directories(SYSTEM ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
link_directories(${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
link_libraries(${MPI_CXX_LIBRARIES})
add_executable(dist_mnist_mlp mnist_loader.cpp dist_mnist_mlp.cpp) add_executable(dist_mnist_mlp mnist_loader.cpp dist_mnist_mlp.cpp)
add_dependencies(dist_mnist_mlp ngraph cpu_backend) add_dependencies(dist_mnist_mlp ngraph cpu_backend)
target_link_libraries(dist_mnist_mlp ngraph cpu_backend) target_link_libraries(dist_mnist_mlp ngraph cpu_backend)
......
...@@ -20,13 +20,13 @@ ...@@ -20,13 +20,13 @@
#include <list> #include <list>
#include <math.h> #include <math.h>
#include <memory> #include <memory>
#include <mpi.h>
#include <random> #include <random>
#include <set> #include <set>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <ngraph/autodiff/adjoints.hpp> #include <ngraph/autodiff/adjoints.hpp>
#include <ngraph/distributed.hpp>
#include <ngraph/graph_util.hpp> #include <ngraph/graph_util.hpp>
#include <ngraph/ngraph.hpp> #include <ngraph/ngraph.hpp>
...@@ -109,7 +109,7 @@ float test_accuracy(MNistDataLoader& loader, ...@@ -109,7 +109,7 @@ float test_accuracy(MNistDataLoader& loader,
int main(int argc, const char* argv[]) int main(int argc, const char* argv[])
{ {
MPI::Init(); ngraph::Distributed dist;
size_t epochs = 5; size_t epochs = 5;
size_t batch_size = 128; size_t batch_size = 128;
...@@ -291,7 +291,5 @@ int main(int argc, const char* argv[]) ...@@ -291,7 +291,5 @@ int main(int argc, const char* argv[])
} }
} }
MPI::Finalize();
return 0; return 0;
} }
...@@ -22,9 +22,8 @@ To deploy data-parallel training on backends supported by nGraph API, the ...@@ -22,9 +22,8 @@ To deploy data-parallel training on backends supported by nGraph API, the
:lines: 180-196 :lines: 180-196
:emphasize-lines: 9-12 :emphasize-lines: 9-12
Also since we are using OpenMPI in this example, we need to initialize and We need to initialize and finalize distributed training with ``Distributed`` object;
finalize MPI with ``MPI::Init();`` and ``MPI::Finalize();`` at the beginning see the `full raw code`_.
and the end of the code used to deploy to devices; see the `full raw code`_.
Finally, to run the training using two nGraph devices, invoke :command:`mpirun`. Finally, to run the training using two nGraph devices, invoke :command:`mpirun`.
This will launch two nGraph CPU backends. This will launch two nGraph CPU backends.
...@@ -36,4 +35,4 @@ This will launch two nGraph CPU backends. ...@@ -36,4 +35,4 @@ This will launch two nGraph CPU backends.
.. _OpenMPI: https://www.open-mpi.org/software/ompi/v3.1 .. _OpenMPI: https://www.open-mpi.org/software/ompi/v3.1
.. _full raw code: https://github.com/NervanaSystems/ngraph/blob/master/doc/examples/mnist_mlp/dist_mnist_mlp.cpp .. _full raw code: https://github.com/NervanaSystems/ngraph/blob/master/doc/examples/mnist_mlp/dist_mnist_mlp.cpp
\ No newline at end of file
...@@ -150,6 +150,15 @@ set (SRC ...@@ -150,6 +150,15 @@ set (SRC
cpio.cpp cpio.cpp
) )
if(NGRAPH_DISTRIBUTED_ENABLE)
find_package(MPI REQUIRED)
add_definitions(-DNGRAPH_DISTRIBUTED)
include_directories(SYSTEM ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
link_directories(${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
link_libraries(${MPI_CXX_LIBRARIES})
set (SRC distributed.cpp ${SRC})
endif()
add_subdirectory(frontend) add_subdirectory(frontend)
find_package(Graphviz QUIET) find_package(Graphviz QUIET)
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/distributed.hpp"
#include <mpi.h>
using namespace ngraph;
ngraph::Distributed::Distributed()
{
int flag = 0;
MPI_Initialized(&flag);
if (!flag)
{
MPI_Init(NULL, NULL);
}
}
ngraph::Distributed::~Distributed()
{
MPI_Finalize();
}
int ngraph::Distributed::get_size() const
{
int size;
MPI_Comm_size(MPI_COMM_WORLD, &size);
return size;
}
int ngraph::Distributed::get_rank() const
{
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return rank;
}
#endif
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
namespace ngraph
{
class Distributed
{
public:
Distributed();
~Distributed();
int get_size() const;
int get_rank() const;
};
}
...@@ -36,12 +36,11 @@ TEST(distributed_${BACKEND_NAME}, allreduce) ...@@ -36,12 +36,11 @@ TEST(distributed_${BACKEND_NAME}, allreduce)
auto f = make_shared<Function>(make_shared<op::AllReduce>(A), op::ParameterVector{A}); auto f = make_shared<Function>(make_shared<op::AllReduce>(A), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto v = vector<float>{1, 2, 3, 4};
int comm_size; int comm_size;
MPI_Comm_size(MPI_COMM_WORLD, &comm_size); MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
auto v = vector<float>{1, 2, 3, 4};
auto a = backend->create_tensor(element::f32, shape); auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1, 2, 3, 4}); copy_data(a, vector<float>{1, 2, 3, 4});
......
...@@ -23,28 +23,14 @@ ...@@ -23,28 +23,14 @@
using namespace std; using namespace std;
#ifdef NGRAPH_DISTRIBUTED #ifdef NGRAPH_DISTRIBUTED
#include <mpi.h> #include "ngraph/distributed.hpp"
class MpiEnvironment : public ::testing::Environment
{
protected:
virtual void SetUp()
{
int flag = 0;
MPI_Initialized(&flag);
if (!flag)
{
MPI::Init();
}
}
virtual void TearDown() { MPI::Finalize(); }
virtual ~MpiEnvironment() {}
};
#endif #endif
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
#ifdef NGRAPH_DISTRIBUTED
ngraph::Distributed dist;
#endif
const char* exclude = "--gtest_filter=-benchmark.*"; const char* exclude = "--gtest_filter=-benchmark.*";
vector<char*> argv_vector; vector<char*> argv_vector;
argv_vector.push_back(argv[0]); argv_vector.push_back(argv[0]);
...@@ -56,9 +42,6 @@ int main(int argc, char** argv) ...@@ -56,9 +42,6 @@ int main(int argc, char** argv)
argc++; argc++;
::testing::InitGoogleTest(&argc, argv_vector.data()); ::testing::InitGoogleTest(&argc, argv_vector.data());
#ifdef NGRAPH_DISTRIBUTED
::testing::AddGlobalTestEnvironment(new MpiEnvironment);
#endif
int rc = RUN_ALL_TESTS(); int rc = RUN_ALL_TESTS();
return rc; return rc;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment