Commit 0fa17649 authored by Sevin F. Varoglu, committed by Robert Kimball

refactor distributed code (#1619)

* refactor distributed code

* add test file changes

* add distributed class

* update test file
parent 94f901ed
@@ -18,11 +18,6 @@ add_executable(mnist_mlp mnist_loader.cpp mnist_mlp.cpp)
 add_dependencies(mnist_mlp ngraph cpu_backend)
 target_link_libraries(mnist_mlp ngraph cpu_backend)
 if (NGRAPH_DISTRIBUTED_ENABLE)
-    find_package(MPI REQUIRED)
-    add_definitions(-DNGRAPH_DISTRIBUTED)
-    include_directories(SYSTEM ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
-    link_directories(${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
-    link_libraries(${MPI_CXX_LIBRARIES})
     add_executable(dist_mnist_mlp mnist_loader.cpp dist_mnist_mlp.cpp)
     add_dependencies(dist_mnist_mlp ngraph cpu_backend)
     target_link_libraries(dist_mnist_mlp ngraph cpu_backend)
......
@@ -20,13 +20,13 @@
 #include <list>
 #include <math.h>
 #include <memory>
-#include <mpi.h>
 #include <random>
 #include <set>
 #include <stdexcept>
 #include <string>

 #include <ngraph/autodiff/adjoints.hpp>
+#include <ngraph/distributed.hpp>
 #include <ngraph/graph_util.hpp>
 #include <ngraph/ngraph.hpp>
@@ -109,7 +109,7 @@ float test_accuracy(MNistDataLoader& loader,
 int main(int argc, const char* argv[])
 {
-    MPI::Init();
+    ngraph::Distributed dist;

     size_t epochs = 5;
     size_t batch_size = 128;
@@ -291,7 +291,5 @@ int main(int argc, const char* argv[])
         }
     }
-    MPI::Finalize();
     return 0;
 }
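With this change, MPI setup and teardown in the example are handled by the ``ngraph::Distributed`` object's constructor and destructor, so the explicit ``MPI::Init()`` and ``MPI::Finalize()`` calls are no longer needed; the implementation of that class appears in ``src/ngraph/distributed.cpp`` further down in this commit.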
@@ -22,9 +22,8 @@ To deploy data-parallel training on backends supported by nGraph API, the
    :lines: 180-196
    :emphasize-lines: 9-12

-Also, since we are using OpenMPI in this example, we need to initialize and
-finalize MPI with ``MPI::Init();`` and ``MPI::Finalize();`` at the beginning
-and the end of the code used to deploy to devices; see the `full raw code`_.
+We need to initialize and finalize distributed training with the ``Distributed``
+object; see the `full raw code`_.

 Finally, to run the training using two nGraph devices, invoke :command:`mpirun`.
 This will launch two nGraph CPU backends.
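As a usage note (an assumption about the invocation, not text from the commit): with the ``dist_mnist_mlp`` binary produced by the CMake rules above, the launch would look something like ``mpirun -np 2 dist_mnist_mlp``, where ``-np 2`` starts the two CPU-backend processes; the exact flags depend on the local OpenMPI installation.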
@@ -36,4 +35,4 @@ This will launch two nGraph CPU backends.

 .. _OpenMPI: https://www.open-mpi.org/software/ompi/v3.1

-.. _full raw code: https://github.com/NervanaSystems/ngraph/blob/master/doc/examples/mnist_mlp/dist_mnist_mlp.cpp
\ No newline at end of file
+.. _full raw code: https://github.com/NervanaSystems/ngraph/blob/master/doc/examples/mnist_mlp/dist_mnist_mlp.cpp
@@ -150,6 +150,15 @@ set (SRC
     cpio.cpp
     )

+if(NGRAPH_DISTRIBUTED_ENABLE)
+    find_package(MPI REQUIRED)
+    add_definitions(-DNGRAPH_DISTRIBUTED)
+    include_directories(SYSTEM ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
+    link_directories(${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
+    link_libraries(${MPI_CXX_LIBRARIES})
+    set (SRC distributed.cpp ${SRC})
+endif()
+
 add_subdirectory(frontend)

 find_package(Graphviz QUIET)
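A build-side note (the option wiring is an assumption, since it is not shown in this diff): because ``distributed.cpp`` is compiled only inside this guard, the distributed code paths are enabled by configuring with something along the lines of ``cmake -DNGRAPH_DISTRIBUTED_ENABLE=ON ..``, which also defines ``NGRAPH_DISTRIBUTED`` for the ``#ifdef`` blocks seen in the sources below.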
......
src/ngraph/distributed.cpp (new file):
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/distributed.hpp"
#include <mpi.h>
using namespace ngraph;
ngraph::Distributed::Distributed()
{
    // Initialize MPI only if the process has not already done so.
    int flag = 0;
    MPI_Initialized(&flag);
    if (!flag)
    {
        MPI_Init(NULL, NULL);
    }
}

ngraph::Distributed::~Distributed()
{
    // Tie MPI teardown to object lifetime: finalize on destruction.
    MPI_Finalize();
}

int ngraph::Distributed::get_size() const
{
    int size;
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    return size;
}

int ngraph::Distributed::get_rank() const
{
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    return rank;
}
#endif
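A note on the design (a reading of the code above, not commit text): construction is safe in a process that has already called ``MPI_Init``, thanks to the ``MPI_Initialized`` check, but the destructor always calls ``MPI_Finalize``; a single ``Distributed`` instance is therefore expected to own MPI for the lifetime of the process, as in the example and the test harness in this commit.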
src/ngraph/distributed.hpp (new file):
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once

namespace ngraph
{
    class Distributed
    {
    public:
        Distributed();
        ~Distributed();

        int get_size() const;
        int get_rank() const;
    };
}
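To show the intended use of the new class, here is a minimal, hypothetical example (not part of the commit; it assumes a build with ``NGRAPH_DISTRIBUTED`` defined, and mirrors the pattern in ``dist_mnist_mlp.cpp`` above):

#include <iostream>

#include "ngraph/distributed.hpp"

int main()
{
    // Constructing the object initializes MPI if nothing else has;
    // MPI_Finalize() runs automatically when `dist` goes out of scope.
    ngraph::Distributed dist;

    std::cout << "rank " << dist.get_rank() << " of " << dist.get_size()
              << std::endl;

    return 0;
}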
@@ -36,12 +36,11 @@ TEST(distributed_${BACKEND_NAME}, allreduce)
     auto f = make_shared<Function>(make_shared<op::AllReduce>(A), op::ParameterVector{A});

     auto backend = runtime::Backend::create("${BACKEND_NAME}");

+    auto v = vector<float>{1, 2, 3, 4};
-    int comm_size;
-    MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
-    auto v = vector<float>{1, 2, 3, 4};

     auto a = backend->create_tensor(element::f32, shape);
     copy_data(a, vector<float>{1, 2, 3, 4});
......
@@ -23,28 +23,14 @@
 using namespace std;

 #ifdef NGRAPH_DISTRIBUTED
-#include <mpi.h>
-
-class MpiEnvironment : public ::testing::Environment
-{
-protected:
-    virtual void SetUp()
-    {
-        int flag = 0;
-        MPI_Initialized(&flag);
-        if (!flag)
-        {
-            MPI::Init();
-        }
-    }
-    virtual void TearDown() { MPI::Finalize(); }
-    virtual ~MpiEnvironment() {}
-};
+#include "ngraph/distributed.hpp"
 #endif

 int main(int argc, char** argv)
 {
+#ifdef NGRAPH_DISTRIBUTED
+    ngraph::Distributed dist;
+#endif
     const char* exclude = "--gtest_filter=-benchmark.*";
     vector<char*> argv_vector;
     argv_vector.push_back(argv[0]);
@@ -56,9 +42,6 @@ int main(int argc, char** argv)
     argc++;

     ::testing::InitGoogleTest(&argc, argv_vector.data());
-#ifdef NGRAPH_DISTRIBUTED
-    ::testing::AddGlobalTestEnvironment(new MpiEnvironment);
-#endif
     int rc = RUN_ALL_TESTS();
     return rc;
......
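The harness change is worth a sentence (an observation on the diff, not commit text): a stack-allocated ``ngraph::Distributed`` at the top of ``main`` guarantees MPI is initialized before ``InitGoogleTest`` runs and finalized only after ``RUN_ALL_TESTS`` returns, which removes the need for the custom gtest ``MpiEnvironment`` previously registered via ``AddGlobalTestEnvironment``.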