Commit 563af715 authored by Daiki AMINAKA, committed by Scott Cyphers

Support arbitrary reduction op (#3048)

parent 84ba3a2a
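For context, here is a minimal, hedged usage sketch of the API this commit introduces: constructing an op::AllReduce with an explicit reduction type instead of the previously implicit sum. The helper function name is illustrative, and the umbrella header is assumed to provide Function, Parameter and Shape; only the AllReduce constructor and reduction::max come from the commit itself.

#include <memory>
#include "ngraph/ngraph.hpp"       // umbrella header, assumed to pull in Function, Parameter, Shape
#include "ngraph/op/allreduce.hpp" // include path as used in the interpreter hunk below

// Illustrative helper: build a function whose output is the element-wise
// maximum of the parameter across all ranks.
std::shared_ptr<ngraph::Function> make_allreduce_max_function()
{
    auto shape = ngraph::Shape{2, 2};
    auto arg = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, shape);
    // The second constructor argument is new in this commit; it defaults to reduction::sum.
    auto allreduce = std::make_shared<ngraph::op::AllReduce>(arg, ngraph::reduction::max);
    return std::make_shared<ngraph::Function>(allreduce, ngraph::ParameterVector{arg});
}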
@@ -22,6 +22,41 @@
using namespace ngraph;
NGRAPH_API const reduction::Type reduction::sum(reduction::Type_t::sum);
NGRAPH_API const reduction::Type reduction::prod(reduction::Type_t::prod);
NGRAPH_API const reduction::Type reduction::min(reduction::Type_t::min);
NGRAPH_API const reduction::Type reduction::max(reduction::Type_t::max);
std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj)
{
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (obj.get_type())
{
case reduction::Type_t::sum: out << "sum"; break;
case reduction::Type_t::prod: out << "prod"; break;
case reduction::Type_t::min: out << "min"; break;
case reduction::Type_t::max: out << "max"; break;
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif
return out;
};
bool reduction::Type::operator==(const reduction::Type& other) const
{
return m_type == other.m_type;
}
reduction::Type_t reduction::Type::get_type() const
{
return m_type;
}
static std::unique_ptr<DistributedInterface> s_distributed_interface;
void ngraph::set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface)
...
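A short, hedged sketch of how the reduction::Type wrapper defined above behaves (comparison, get_type(), and streaming). The demo() function and the include path are illustrative, not part of the commit.

#include <cassert>
#include <iostream>
#include "ngraph/distributed.hpp" // assumed include path for the definitions above

void demo()
{
    ngraph::reduction::Type t = ngraph::reduction::min;
    assert(t == ngraph::reduction::min);
    assert(t != ngraph::reduction::sum);
    assert(t.get_type() == ngraph::reduction::Type_t::min);
    std::cout << t << std::endl; // prints "min" via the operator<< defined above
}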
@@ -24,6 +24,38 @@
namespace ngraph
{
namespace reduction
{
enum class Type_t
{
sum,
prod,
min,
max,
};
class Type
{
public:
Type(const Type_t t)
: m_type(t)
{
}
friend std::ostream& operator<<(std::ostream&, const Type&);
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
Type_t get_type() const;
private:
Type_t m_type;
};
std::ostream& operator<<(std::ostream& out, const Type& obj);
extern NGRAPH_API const Type sum;
extern NGRAPH_API const Type prod;
extern NGRAPH_API const Type min;
extern NGRAPH_API const Type max;
}
class DistributedInterface
{
public:
@@ -33,8 +65,11 @@ namespace ngraph
virtual int get_rank() = 0;
virtual void log_print(const std::string& timestamp, const std::vector<char>& buf) = 0;
virtual void all_reduce(void* in,
void* out,
element::Type_t element_type,
reduction::Type reduce_type,
size_t count) = 0;
virtual void
broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0;
};
...
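A hedged sketch of what a distributed backend now has to implement. Only the all_reduce signature (with the new reduction::Type parameter) comes from the diff above; the class name and its trivial single-process behaviour are illustrative, and any other pure-virtual members of DistributedInterface are omitted here, so the sketch stays abstract.

#include <cstring>
#include <string>
#include <vector>
#include "ngraph/distributed.hpp" // assumed include path

class NullReduceInterface : public ngraph::DistributedInterface
{
public:
    int get_rank() override { return 0; }
    void log_print(const std::string& timestamp, const std::vector<char>& buf) override {}
    void all_reduce(void* in,
                    void* out,
                    ngraph::element::Type_t element_type,
                    ngraph::reduction::Type reduce_type,
                    size_t count) override
    {
        // With a single process every reduction is the identity, so reduce_type is ignored.
        size_t element_size =
            (element_type == ngraph::element::Type_t::f64) ? sizeof(double) : sizeof(float);
        std::memcpy(out, in, count * element_size);
    }
    void broadcast(void* in, ngraph::element::Type_t element_type, size_t count, int root_id) override
    {
    }
};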
@@ -65,8 +65,11 @@ namespace ngraph
std::printf("%s [MLSL RANK: %d]: %s\n", timestamp.c_str(), get_rank(), buf.data());
}
void all_reduce(void* in,
void* out,
element::Type_t element_type,
reduction::Type reduce_type,
size_t count) override
{
auto data_type = MLSL::DT_FLOAT;
@@ -83,10 +86,29 @@
throw std::runtime_error("AllReduce op supports only f32 and f64 types");
}
decltype(MLSL::RT_SUM) mlsl_reduce_type;
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (reduce_type.get_type())
{
case reduction::Type_t::sum: mlsl_reduce_type = MLSL::RT_SUM; break;
case reduction::Type_t::prod:
throw std::runtime_error("MLSL doesn't support allreduce prod");
break;
case reduction::Type_t::min: mlsl_reduce_type = MLSL::RT_MIN; break;
case reduction::Type_t::max: mlsl_reduce_type = MLSL::RT_MAX; break;
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif
MLSL::Environment& env = MLSL::Environment::GetEnv();
MLSL::Distribution* distribution = env.CreateDistribution(env.GetProcessCount(), 1);
MLSL::CommReq* req = distribution->AllReduce(
in, out, count, data_type, mlsl_reduce_type, MLSL::GT_DATA);
env.Wait(req);
env.DeleteDistribution(distribution);
}
...
@@ -35,8 +35,11 @@ namespace ngraph
{
std::printf("%s: %s\n", timestamp.c_str(), buf.data());
}
void all_reduce(void* in,
void* out,
element::Type_t element_type,
reduction::Type reduce_type,
size_t count) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
...
@@ -77,8 +77,11 @@ namespace ngraph
"%s [OpenMPI RANK: %d]: %s\n", timestamp.c_str(), get_rank(), buf.data());
}
void all_reduce(void* in,
void* out,
element::Type_t element_type,
reduction::Type reduce_type,
size_t count) override
{
auto data_type = MPI_FLOAT;
@@ -95,7 +98,24 @@
throw std::runtime_error("AllReduce op supports only f32 and f64 types");
}
decltype(MPI_SUM) mpi_reduce_type;
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (reduce_type.get_type())
{
case reduction::Type_t::sum: mpi_reduce_type = MPI_SUM; break;
case reduction::Type_t::prod: mpi_reduce_type = MPI_PROD; break;
case reduction::Type_t::min: mpi_reduce_type = MPI_MIN; break;
case reduction::Type_t::max: mpi_reduce_type = MPI_MAX; break;
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif
MPI_Allreduce(in, out, count, data_type, mpi_reduce_type, MPI_COMM_WORLD);
}
void broadcast(void* in,
...
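A hedged usage sketch (not part of the commit): invoking the extended all_reduce entry point directly through the active distributed interface. It assumes an MPI- or MLSL-enabled build; the helper name and buffers are illustrative.

#include <vector>
#include "ngraph/distributed.hpp" // assumed include path

// Illustrative helper: replace 'values' with the element-wise maximum across all ranks.
void allreduce_max_f32(std::vector<float>& values)
{
    std::vector<float> result(values.size());
    ngraph::get_distributed_interface()->all_reduce(values.data(),
                                                    result.data(),
                                                    ngraph::element::Type_t::f32,
                                                    ngraph::reduction::max,
                                                    values.size());
    values = result; // every rank now holds the reduced values
}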
@@ -22,11 +22,13 @@ using namespace ngraph;
const string op::AllReduce::type_name{"AllReduce"};
op::AllReduce::AllReduce()
: m_reduce_type(reduction::sum)
{
}
op::AllReduce::AllReduce(const shared_ptr<Node>& arg, const reduction::Type reduce_type)
: Op(check_single_output_args({arg}))
, m_reduce_type(reduce_type)
{
constructor_validate_and_infer_types();
}
@@ -47,5 +49,10 @@ void op::AllReduce::validate_and_infer_types()
shared_ptr<Node> op::AllReduce::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<AllReduce>(new_args.at(0), get_reduce_type());
}
reduction::Type op::AllReduce::get_reduce_type() const
{
return m_reduce_type;
}
@@ -30,11 +30,16 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
AllReduce();
AllReduce(const std::shared_ptr<Node>& arg,
const reduction::Type reduce_type = reduction::sum);
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
reduction::Type get_reduce_type() const;
private:
const reduction::Type m_reduce_type;
};
}
}
@@ -37,6 +37,9 @@ namespace ngraph
auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto count = static_cast<int>(out[0].get_size());
auto data_type = args[0].get_element_type().get_type_enum();
const ngraph::op::AllReduce* allreduce =
static_cast<const ngraph::op::AllReduce*>(node);
auto reduce_type = allreduce->get_reduce_type();
auto external_function_name = external_function->get_function_name();
NGRAPH_DEBUG_PRINT(
@@ -48,11 +51,13 @@
node->get_friendly_name().c_str(),
count);
auto functor =
[&, count, reduce_type, data_type, arg_buffer_index, out_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
get_distributed_interface()->all_reduce(ctx->buffer_data[arg_buffer_index],
ctx->buffer_data[out_buffer_index],
data_type,
reduce_type,
count);
};
functors.emplace_back(functor);
...
@@ -263,10 +263,13 @@ namespace ngraph
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::AllReduce)
{
const ngraph::op::AllReduce* allreduce =
static_cast<const ngraph::op::AllReduce*>(node);
writer << "ngraph::get_distributed_interface()->all_reduce(" << args[0].get_name() writer << "ngraph::get_distributed_interface()->all_reduce(" << args[0].get_name()
<< ", " << out[0].get_name() << ", " << ", " << out[0].get_name() << ", "
<< "ngraph::element::Type_t::" << args[0].get_element_type().get_type_name() << "ngraph::element::Type_t::" << args[0].get_element_type().get_type_name()
<< ", " << out[0].get_size() << ");\n"; << ", " << out[0].get_size() << ", "
<< "ngraph::Reduce_t::" << allreduce->get_reduce_type() << ");\n";
} }
template <> template <>
......
@@ -24,6 +24,7 @@
#include <vector>
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/any.hpp" #include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp" #include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp" #include "ngraph/op/argmin.hpp"
...@@ -254,9 +255,12 @@ private: ...@@ -254,9 +255,12 @@ private:
} }
case OP_TYPEID::AllReduce: case OP_TYPEID::AllReduce:
{ {
const ngraph::op::AllReduce* allreduce =
static_cast<const ngraph::op::AllReduce*>(&node);
reference::allreduce<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
allreduce->get_reduce_type(),
static_cast<int>(shape_size(node.get_input_shape(0))));
break;
}
...
@@ -25,9 +25,13 @@ namespace ngraph
namespace reference
{
template <typename T>
void allreduce(T* arg,
T* out,
const element::Type_t element_type,
const reduction::Type reduce_type,
int count)
{
get_distributed_interface()->all_reduce(arg, out, element_type, reduce_type, count);
}
}
}
...
@@ -29,25 +29,61 @@
using namespace std;
using namespace ngraph;
static void test_allreduce_common(reduction::Type reduce_type)
{
auto comm_size = get_distributed_interface()->get_size();
if (comm_size > 1)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::AllReduce>(A, reduce_type), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto v = vector<float>{1, 2, 3, 4};
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1, 2, 3, 4});
auto result = backend->create_tensor(element::f32, shape);
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (reduce_type.get_type())
{
case reduction::Type_t::sum:
copy_data(a, v);
std::transform(
v.begin(), v.end(), v.begin(), std::bind1st(std::multiplies<float>(), comm_size));
break;
case reduction::Type_t::prod:
copy_data(a, v);
std::transform(v.begin(), v.end(), v.begin(), [&](float elm) -> float {
return pow(elm, comm_size);
});
break;
case reduction::Type_t::min:
case reduction::Type_t::max:
auto shift = get_distributed_interface()->get_rank();
std::rotate(v.begin(), v.begin() + shift % v.size(), v.end());
copy_data(a, v);
if (reduce_type == reduction::Type_t::min)
{
std::fill(v.begin(), v.end(), 1);
for (int i = 1; i < static_cast<int>(v.size()) - comm_size + 1; i++)
v[i] = i + 1;
}
else
{
std::fill(v.begin(), v.end(), v.size());
for (int i = 0; i < static_cast<int>(v.size()) - comm_size; i++)
v[i] = i + 2;
}
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
@@ -55,6 +91,28 @@ TEST(distributed_${BACKEND_NAME}, allreduce)
}
}
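The min/max branches of the test above are easiest to follow with a concrete case. A hedged worked example, assuming comm_size == 2 and the rotation scheme used by the test (rank r rotates v left by r before copy_data); the check program is illustrative and not part of the commit.

#include <algorithm>
#include <cassert>
#include <vector>

int main()
{
    std::vector<float> rank0{1, 2, 3, 4};        // rank 0: no rotation
    std::vector<float> rank1{2, 3, 4, 1};        // rank 1: {1,2,3,4} rotated left by 1
    std::vector<float> expected_min{1, 2, 3, 1}; // what the min loop above constructs
    std::vector<float> expected_max{2, 3, 4, 4}; // what the max loop above constructs
    for (size_t i = 0; i < rank0.size(); ++i)
    {
        // The element-wise reductions across the two ranks match the expected vectors.
        assert(std::min(rank0[i], rank1[i]) == expected_min[i]);
        assert(std::max(rank0[i], rank1[i]) == expected_max[i]);
    }
    return 0;
}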
TEST(distributed_${BACKEND_NAME}, allreduce_sum)
{
test_allreduce_common(reduction::sum);
}
TEST(distributed_${BACKEND_NAME}, allreduce_min)
{
test_allreduce_common(reduction::min);
}
TEST(distributed_${BACKEND_NAME}, allreduce_max)
{
test_allreduce_common(reduction::max);
}
#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
TEST(distributed_${BACKEND_NAME}, allreduce_prod)
{
test_allreduce_common(reduction::prod);
}
#endif
TEST(distributed_${BACKEND_NAME}, broadcastdistributed)
{
auto shape = Shape{2, 2};
...