Commit 32b0e7e0 authored by Daiki AMINAKA's avatar Daiki AMINAKA Committed by Scott Cyphers

Support any root id for broadcast distributed (#3052)

parent 41e1182f
......@@ -35,7 +35,8 @@ namespace ngraph
virtual void
all_reduce(void* in, void* out, element::Type_t element_type, size_t count) = 0;
virtual void broadcast(void* in, element::Type_t element_type, size_t count) = 0;
virtual void
broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0;
};
void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface);
......
......@@ -91,7 +91,10 @@ namespace ngraph
env.DeleteDistribution(distribution);
}
void broadcast(void* in, element::Type_t element_type, size_t count) override
void broadcast(void* in,
element::Type_t element_type,
size_t count,
int root_id) override
{
auto data_type = MLSL::DT_FLOAT;
......@@ -107,7 +110,8 @@ namespace ngraph
MLSL::Environment& env = MLSL::Environment::GetEnv();
MLSL::Distribution* distribution = env.CreateDistribution(env.GetProcessCount(), 1);
MLSL::CommReq* req = distribution->Bcast(in, count, data_type, 0, MLSL::GT_DATA);
MLSL::CommReq* req =
distribution->Bcast(in, count, data_type, root_id, MLSL::GT_DATA);
env.Wait(req);
env.DeleteDistribution(distribution);
}
......
......@@ -41,7 +41,10 @@ namespace ngraph
throw ngraph_error("Distributed Library not supported/mentioned");
}
void broadcast(void* in, element::Type_t element_type, size_t count) override
void broadcast(void* in,
element::Type_t element_type,
size_t count,
int root_id) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
......
......@@ -98,7 +98,10 @@ namespace ngraph
MPI_Allreduce(in, out, count, data_type, MPI_SUM, MPI_COMM_WORLD);
}
void broadcast(void* in, element::Type_t element_type, size_t count) override
void broadcast(void* in,
element::Type_t element_type,
size_t count,
int root_id) override
{
auto data_type = MPI_FLOAT;
......@@ -111,7 +114,7 @@ namespace ngraph
throw std::runtime_error(
"BroadcastDistributed op supports only f32 and f64 types");
}
MPI_Bcast(in, count, data_type, 0, MPI_COMM_WORLD);
MPI_Bcast(in, count, data_type, root_id, MPI_COMM_WORLD);
}
protected:
......
......@@ -19,8 +19,9 @@
using namespace std;
using namespace ngraph;
op::BroadcastDistributed::BroadcastDistributed(const shared_ptr<Node>& arg)
op::BroadcastDistributed::BroadcastDistributed(const shared_ptr<Node>& arg, int root_id)
: Op("BroadcastDistributed", check_single_output_args({arg}))
, m_root_id(root_id)
{
constructor_validate_and_infer_types();
}
......@@ -41,5 +42,10 @@ void op::BroadcastDistributed::validate_and_infer_types()
shared_ptr<Node> op::BroadcastDistributed::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<BroadcastDistributed>(new_args.at(0));
return make_shared<BroadcastDistributed>(new_args.at(0), m_root_id);
}
int op::BroadcastDistributed::get_root_id() const
{
return m_root_id;
}
......@@ -27,12 +27,16 @@ namespace ngraph
class BroadcastDistributed : public Op
{
public:
BroadcastDistributed(const std::shared_ptr<Node>& arg);
BroadcastDistributed(const std::shared_ptr<Node>& arg, int root_id = 0);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
int get_root_id() const;
private:
const int m_root_id;
};
}
}
......@@ -34,10 +34,12 @@ namespace ngraph
auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto count = static_cast<int>(args[0].get_size());
auto data_type = args[0].get_element_type().get_type_enum();
auto functor = [&, count, data_type, arg_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
auto broadcast = static_cast<const ngraph::op::BroadcastDistributed*>(node);
auto root_id = broadcast->get_root_id();
auto functor = [&, count, data_type, arg_buffer_index, root_id](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
get_distributed_interface()->broadcast(
ctx->buffer_data[arg_buffer_index], data_type, count);
ctx->buffer_data[arg_buffer_index], data_type, count, root_id);
};
functors.emplace_back(functor);
}
......
......@@ -30,6 +30,7 @@
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
......@@ -482,14 +483,18 @@ private:
}
case OP_TYPEID::BroadcastDistributed:
{
const ngraph::op::BroadcastDistributed* broadcast =
static_cast<const ngraph::op::BroadcastDistributed*>(&node);
int rank_ID;
rank_ID = get_distributed_interface()->get_rank();
if (rank_ID == 0)
int root_id = broadcast->get_root_id();
if (rank_ID == root_id)
{
reference::broadcastdistributed<T>(
args[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
static_cast<int>(shape_size(node.get_input_shape(0))));
static_cast<int>(shape_size(node.get_input_shape(0))),
root_id);
auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) * sizeof(T);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
}
......@@ -498,7 +503,8 @@ private:
reference::broadcastdistributed<T>(
out[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
static_cast<int>(shape_size(node.get_input_shape(0))));
static_cast<int>(shape_size(node.get_input_shape(0))),
root_id);
}
break;
}
......
......@@ -27,9 +27,12 @@ namespace ngraph
namespace reference
{
template <typename T>
void broadcastdistributed(T* arg, const element::Type_t element_type, int count)
void broadcastdistributed(T* arg,
const element::Type_t element_type,
int count,
int root_id)
{
get_distributed_interface()->broadcast(arg, element_type, count);
get_distributed_interface()->broadcast(arg, element_type, count, root_id);
}
}
}
......
......@@ -59,7 +59,11 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::BroadcastDistributed>(A), ParameterVector{A});
auto comm_size = get_distributed_interface()->get_size();
for (int root_id = 0; root_id < comm_size; ++root_id)
{
auto f = make_shared<Function>(make_shared<op::BroadcastDistributed>(A, root_id),
ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......@@ -68,7 +72,7 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed)
copy_data(result, vector<float>(4, 0));
auto processIdx = get_distributed_interface()->get_rank();
if (processIdx == 0)
if (processIdx == root_id)
{
copy_data(result, v);
}
......@@ -76,4 +80,5 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed)
auto handle = backend->compile(f);
handle->call_with_validate({result}, {result});
EXPECT_EQ(v, read_vector<float>(result));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment