Commit 32b0e7e0 authored by Daiki AMINAKA's avatar Daiki AMINAKA Committed by Scott Cyphers

Support any root id for broadcast distributed (#3052)

parent 41e1182f
...@@ -35,7 +35,8 @@ namespace ngraph ...@@ -35,7 +35,8 @@ namespace ngraph
virtual void virtual void
all_reduce(void* in, void* out, element::Type_t element_type, size_t count) = 0; all_reduce(void* in, void* out, element::Type_t element_type, size_t count) = 0;
virtual void broadcast(void* in, element::Type_t element_type, size_t count) = 0; virtual void
broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0;
}; };
void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface); void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface);
......
...@@ -91,7 +91,10 @@ namespace ngraph ...@@ -91,7 +91,10 @@ namespace ngraph
env.DeleteDistribution(distribution); env.DeleteDistribution(distribution);
} }
void broadcast(void* in, element::Type_t element_type, size_t count) override void broadcast(void* in,
element::Type_t element_type,
size_t count,
int root_id) override
{ {
auto data_type = MLSL::DT_FLOAT; auto data_type = MLSL::DT_FLOAT;
...@@ -107,7 +110,8 @@ namespace ngraph ...@@ -107,7 +110,8 @@ namespace ngraph
MLSL::Environment& env = MLSL::Environment::GetEnv(); MLSL::Environment& env = MLSL::Environment::GetEnv();
MLSL::Distribution* distribution = env.CreateDistribution(env.GetProcessCount(), 1); MLSL::Distribution* distribution = env.CreateDistribution(env.GetProcessCount(), 1);
MLSL::CommReq* req = distribution->Bcast(in, count, data_type, 0, MLSL::GT_DATA); MLSL::CommReq* req =
distribution->Bcast(in, count, data_type, root_id, MLSL::GT_DATA);
env.Wait(req); env.Wait(req);
env.DeleteDistribution(distribution); env.DeleteDistribution(distribution);
} }
......
...@@ -41,7 +41,10 @@ namespace ngraph ...@@ -41,7 +41,10 @@ namespace ngraph
throw ngraph_error("Distributed Library not supported/mentioned"); throw ngraph_error("Distributed Library not supported/mentioned");
} }
void broadcast(void* in, element::Type_t element_type, size_t count) override void broadcast(void* in,
element::Type_t element_type,
size_t count,
int root_id) override
{ {
throw ngraph_error("Distributed Library not supported/mentioned"); throw ngraph_error("Distributed Library not supported/mentioned");
} }
......
...@@ -98,7 +98,10 @@ namespace ngraph ...@@ -98,7 +98,10 @@ namespace ngraph
MPI_Allreduce(in, out, count, data_type, MPI_SUM, MPI_COMM_WORLD); MPI_Allreduce(in, out, count, data_type, MPI_SUM, MPI_COMM_WORLD);
} }
void broadcast(void* in, element::Type_t element_type, size_t count) override void broadcast(void* in,
element::Type_t element_type,
size_t count,
int root_id) override
{ {
auto data_type = MPI_FLOAT; auto data_type = MPI_FLOAT;
...@@ -111,7 +114,7 @@ namespace ngraph ...@@ -111,7 +114,7 @@ namespace ngraph
throw std::runtime_error( throw std::runtime_error(
"BroadcastDistributed op supports only f32 and f64 types"); "BroadcastDistributed op supports only f32 and f64 types");
} }
MPI_Bcast(in, count, data_type, 0, MPI_COMM_WORLD); MPI_Bcast(in, count, data_type, root_id, MPI_COMM_WORLD);
} }
protected: protected:
......
...@@ -19,8 +19,9 @@ ...@@ -19,8 +19,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::BroadcastDistributed::BroadcastDistributed(const shared_ptr<Node>& arg) op::BroadcastDistributed::BroadcastDistributed(const shared_ptr<Node>& arg, int root_id)
: Op("BroadcastDistributed", check_single_output_args({arg})) : Op("BroadcastDistributed", check_single_output_args({arg}))
, m_root_id(root_id)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -41,5 +42,10 @@ void op::BroadcastDistributed::validate_and_infer_types() ...@@ -41,5 +42,10 @@ void op::BroadcastDistributed::validate_and_infer_types()
shared_ptr<Node> op::BroadcastDistributed::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::BroadcastDistributed::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<BroadcastDistributed>(new_args.at(0)); return make_shared<BroadcastDistributed>(new_args.at(0), m_root_id);
}
int op::BroadcastDistributed::get_root_id() const
{
return m_root_id;
} }
...@@ -27,12 +27,16 @@ namespace ngraph ...@@ -27,12 +27,16 @@ namespace ngraph
class BroadcastDistributed : public Op class BroadcastDistributed : public Op
{ {
public: public:
BroadcastDistributed(const std::shared_ptr<Node>& arg); BroadcastDistributed(const std::shared_ptr<Node>& arg, int root_id = 0);
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
int get_root_id() const;
private:
const int m_root_id;
}; };
} }
} }
...@@ -34,10 +34,12 @@ namespace ngraph ...@@ -34,10 +34,12 @@ namespace ngraph
auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto count = static_cast<int>(args[0].get_size()); auto count = static_cast<int>(args[0].get_size());
auto data_type = args[0].get_element_type().get_type_enum(); auto data_type = args[0].get_element_type().get_type_enum();
auto functor = [&, count, data_type, arg_buffer_index](CPURuntimeContext* ctx, auto broadcast = static_cast<const ngraph::op::BroadcastDistributed*>(node);
CPUExecutionContext* ectx) { auto root_id = broadcast->get_root_id();
auto functor = [&, count, data_type, arg_buffer_index, root_id](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
get_distributed_interface()->broadcast( get_distributed_interface()->broadcast(
ctx->buffer_data[arg_buffer_index], data_type, count); ctx->buffer_data[arg_buffer_index], data_type, count, root_id);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
...@@ -482,14 +483,18 @@ private: ...@@ -482,14 +483,18 @@ private:
} }
case OP_TYPEID::BroadcastDistributed: case OP_TYPEID::BroadcastDistributed:
{ {
const ngraph::op::BroadcastDistributed* broadcast =
static_cast<const ngraph::op::BroadcastDistributed*>(&node);
int rank_ID; int rank_ID;
rank_ID = get_distributed_interface()->get_rank(); rank_ID = get_distributed_interface()->get_rank();
if (rank_ID == 0) int root_id = broadcast->get_root_id();
if (rank_ID == root_id)
{ {
reference::broadcastdistributed<T>( reference::broadcastdistributed<T>(
args[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(), node.get_input_element_type(0).get_type_enum(),
static_cast<int>(shape_size(node.get_input_shape(0)))); static_cast<int>(shape_size(node.get_input_shape(0))),
root_id);
auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) * sizeof(T); auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) * sizeof(T);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize); memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
} }
...@@ -498,7 +503,8 @@ private: ...@@ -498,7 +503,8 @@ private:
reference::broadcastdistributed<T>( reference::broadcastdistributed<T>(
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(), node.get_input_element_type(0).get_type_enum(),
static_cast<int>(shape_size(node.get_input_shape(0)))); static_cast<int>(shape_size(node.get_input_shape(0))),
root_id);
} }
break; break;
} }
......
...@@ -27,9 +27,12 @@ namespace ngraph ...@@ -27,9 +27,12 @@ namespace ngraph
namespace reference namespace reference
{ {
template <typename T> template <typename T>
void broadcastdistributed(T* arg, const element::Type_t element_type, int count) void broadcastdistributed(T* arg,
const element::Type_t element_type,
int count,
int root_id)
{ {
get_distributed_interface()->broadcast(arg, element_type, count); get_distributed_interface()->broadcast(arg, element_type, count, root_id);
} }
} }
} }
......
...@@ -59,7 +59,11 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed) ...@@ -59,7 +59,11 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed)
{ {
auto shape = Shape{2, 2}; auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape); auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::BroadcastDistributed>(A), ParameterVector{A}); auto comm_size = get_distributed_interface()->get_size();
for (int root_id = 0; root_id < comm_size; ++root_id)
{
auto f = make_shared<Function>(make_shared<op::BroadcastDistributed>(A, root_id),
ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
...@@ -68,7 +72,7 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed) ...@@ -68,7 +72,7 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed)
copy_data(result, vector<float>(4, 0)); copy_data(result, vector<float>(4, 0));
auto processIdx = get_distributed_interface()->get_rank(); auto processIdx = get_distributed_interface()->get_rank();
if (processIdx == 0) if (processIdx == root_id)
{ {
copy_data(result, v); copy_data(result, v);
} }
...@@ -76,4 +80,5 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed) ...@@ -76,4 +80,5 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed)
auto handle = backend->compile(f); auto handle = backend->compile(f);
handle->call_with_validate({result}, {result}); handle->call_with_validate({result}, {result});
EXPECT_EQ(v, read_vector<float>(result)); EXPECT_EQ(v, read_vector<float>(result));
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment