Commit d0a83a35 authored by Scott Cyphers's avatar Scott Cyphers Committed by Robert Kimball

Eliminate wrapped enum (#3111)

* Eliminate wrapped enum
Switch allreduce to new op form

* review comments

* review comments

* typo
parent 8a062c4a
...@@ -22,11 +22,6 @@ ...@@ -22,11 +22,6 @@
using namespace ngraph; using namespace ngraph;
NGRAPH_API const reduction::Type reduction::sum(reduction::Type_t::sum);
NGRAPH_API const reduction::Type reduction::prod(reduction::Type_t::prod);
NGRAPH_API const reduction::Type reduction::min(reduction::Type_t::min);
NGRAPH_API const reduction::Type reduction::max(reduction::Type_t::max);
std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj) std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj)
{ {
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8)) #if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
...@@ -34,12 +29,12 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& ob ...@@ -34,12 +29,12 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& ob
#pragma GCC diagnostic error "-Wswitch" #pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum" #pragma GCC diagnostic error "-Wswitch-enum"
#endif #endif
switch (obj.get_type()) switch (obj)
{ {
case reduction::Type_t::sum: out << "sum"; break; case reduction::Type::SUM: out << "SUM"; break;
case reduction::Type_t::prod: out << "prod"; break; case reduction::Type::PROD: out << "PROD"; break;
case reduction::Type_t::min: out << "min"; break; case reduction::Type::MIN: out << "MIN"; break;
case reduction::Type_t::max: out << "max"; break; case reduction::Type::MAX: out << "MAX"; break;
} }
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8) #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
...@@ -47,16 +42,6 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& ob ...@@ -47,16 +42,6 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& ob
return out; return out;
}; };
bool reduction::Type::operator==(const reduction::Type& other) const
{
return m_type == other.m_type;
}
reduction::Type_t reduction::Type::get_type() const
{
return m_type;
}
static std::unique_ptr<DistributedInterface> s_distributed_interface; static std::unique_ptr<DistributedInterface> s_distributed_interface;
void ngraph::set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface) void ngraph::set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface)
......
...@@ -26,34 +26,15 @@ namespace ngraph ...@@ -26,34 +26,15 @@ namespace ngraph
{ {
namespace reduction namespace reduction
{ {
enum class Type_t enum class Type
{ {
sum, SUM,
prod, PROD,
min, MIN,
max, MAX,
}; };
class Type
{
public:
Type(const Type_t t)
: m_type(t)
{
}
friend std::ostream& operator<<(std::ostream&, const Type&);
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
Type_t get_type() const;
private:
Type_t m_type;
};
std::ostream& operator<<(std::ostream& out, const Type& obj); std::ostream& operator<<(std::ostream& out, const Type& obj);
extern NGRAPH_API const Type sum;
extern NGRAPH_API const Type prod;
extern NGRAPH_API const Type min;
extern NGRAPH_API const Type max;
} }
class DistributedInterface class DistributedInterface
......
...@@ -92,14 +92,14 @@ namespace ngraph ...@@ -92,14 +92,14 @@ namespace ngraph
#pragma GCC diagnostic error "-Wswitch" #pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum" #pragma GCC diagnostic error "-Wswitch-enum"
#endif #endif
switch (reduce_type.get_type()) switch (reduce_type)
{ {
case reduction::Type_t::sum: mlsl_reduce_type = MLSL::RT_SUM; break; case reduction::Type::SUM: mlsl_reduce_type = MLSL::RT_SUM; break;
case reduction::Type_t::prod: case reduction::Type::PROD:
throw std::runtime_error("MLSL doesn't support allreduce prod"); throw std::runtime_error("MLSL doesn't support allreduce prod");
break; break;
case reduction::Type_t::min: mlsl_reduce_type = MLSL::RT_MIN; break; case reduction::Type::MIN: mlsl_reduce_type = MLSL::RT_MIN; break;
case reduction::Type_t::max: mlsl_reduce_type = MLSL::RT_MAX; break; case reduction::Type::MAX: mlsl_reduce_type = MLSL::RT_MAX; break;
} }
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8) #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
......
...@@ -104,12 +104,12 @@ namespace ngraph ...@@ -104,12 +104,12 @@ namespace ngraph
#pragma GCC diagnostic error "-Wswitch" #pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum" #pragma GCC diagnostic error "-Wswitch-enum"
#endif #endif
switch (reduce_type.get_type()) switch (reduce_type)
{ {
case reduction::Type_t::sum: mpi_reduce_type = MPI_SUM; break; case reduction::Type::SUM: mpi_reduce_type = MPI_SUM; break;
case reduction::Type_t::prod: mpi_reduce_type = MPI_PROD; break; case reduction::Type::PROD: mpi_reduce_type = MPI_PROD; break;
case reduction::Type_t::min: mpi_reduce_type = MPI_MIN; break; case reduction::Type::MIN: mpi_reduce_type = MPI_MIN; break;
case reduction::Type_t::max: mpi_reduce_type = MPI_MAX; break; case reduction::Type::MAX: mpi_reduce_type = MPI_MAX; break;
} }
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8) #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
......
...@@ -21,13 +21,8 @@ using namespace ngraph; ...@@ -21,13 +21,8 @@ using namespace ngraph;
const string op::AllReduce::type_name{"AllReduce"}; const string op::AllReduce::type_name{"AllReduce"};
op::AllReduce::AllReduce() op::AllReduce::AllReduce(const Output<Node>& arg, reduction::Type reduce_type)
: m_reduce_type(reduction::sum) : Op({arg})
{
}
op::AllReduce::AllReduce(const shared_ptr<Node>& arg, const reduction::Type reduce_type)
: Op(check_single_output_args({arg}))
, m_reduce_type(reduce_type) , m_reduce_type(reduce_type)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -56,3 +51,8 @@ reduction::Type op::AllReduce::get_reduce_type() const ...@@ -56,3 +51,8 @@ reduction::Type op::AllReduce::get_reduce_type() const
{ {
return m_reduce_type; return m_reduce_type;
} }
void op::AllReduce::set_reduce_type(reduction::Type reduce_type)
{
m_reduce_type = reduce_type;
}
...@@ -29,17 +29,17 @@ namespace ngraph ...@@ -29,17 +29,17 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
AllReduce(); AllReduce() = default;
AllReduce(const std::shared_ptr<Node>& arg, AllReduce(const Output<Node>& arg, reduction::Type reduce_type = reduction::Type::SUM);
const reduction::Type reduce_type = reduction::sum);
void validate_and_infer_types() override; void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
reduction::Type get_reduce_type() const; reduction::Type get_reduce_type() const;
void set_reduce_type(reduction::Type reduce_type);
private: private:
const reduction::Type m_reduce_type; reduction::Type m_reduce_type{reduction::Type::SUM};
}; };
} }
} }
...@@ -50,25 +50,25 @@ static void test_allreduce_common(reduction::Type reduce_type) ...@@ -50,25 +50,25 @@ static void test_allreduce_common(reduction::Type reduce_type)
#pragma GCC diagnostic error "-Wswitch" #pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum" #pragma GCC diagnostic error "-Wswitch-enum"
#endif #endif
switch (reduce_type.get_type()) switch (reduce_type)
{ {
case reduction::Type_t::sum: case reduction::Type::SUM:
copy_data(a, v); copy_data(a, v);
std::transform( std::transform(
v.begin(), v.end(), v.begin(), std::bind1st(std::multiplies<float>(), comm_size)); v.begin(), v.end(), v.begin(), std::bind1st(std::multiplies<float>(), comm_size));
break; break;
case reduction::Type_t::prod: case reduction::Type::PROD:
copy_data(a, v); copy_data(a, v);
std::transform(v.begin(), v.end(), v.begin(), [&](float elm) -> float { std::transform(v.begin(), v.end(), v.begin(), [&](float elm) -> float {
return pow(elm, comm_size); return pow(elm, comm_size);
}); });
break; break;
case reduction::Type_t::min: case reduction::Type::MIN:
case reduction::Type_t::max: case reduction::Type::MAX:
auto shift = get_distributed_interface()->get_rank(); auto shift = get_distributed_interface()->get_rank();
std::rotate(v.begin(), v.begin() + shift % v.size(), v.end()); std::rotate(v.begin(), v.begin() + shift % v.size(), v.end());
copy_data(a, v); copy_data(a, v);
if (reduce_type == reduction::Type_t::min) if (reduce_type == reduction::Type::MIN)
{ {
std::fill(v.begin(), v.end(), 1); std::fill(v.begin(), v.end(), 1);
for (int i = 1; i < static_cast<int>(v.size()) - comm_size + 1; i++) for (int i = 1; i < static_cast<int>(v.size()) - comm_size + 1; i++)
...@@ -93,23 +93,23 @@ static void test_allreduce_common(reduction::Type reduce_type) ...@@ -93,23 +93,23 @@ static void test_allreduce_common(reduction::Type reduce_type)
TEST(distributed_${BACKEND_NAME}, allreduce_sum) TEST(distributed_${BACKEND_NAME}, allreduce_sum)
{ {
test_allreduce_common(reduction::sum); test_allreduce_common(reduction::Type::SUM);
} }
TEST(distributed_${BACKEND_NAME}, allreduce_min) TEST(distributed_${BACKEND_NAME}, allreduce_min)
{ {
test_allreduce_common(reduction::min); test_allreduce_common(reduction::Type::MIN);
} }
TEST(distributed_${BACKEND_NAME}, allreduce_max) TEST(distributed_${BACKEND_NAME}, allreduce_max)
{ {
test_allreduce_common(reduction::max); test_allreduce_common(reduction::Type::MAX);
} }
#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE) #if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
TEST(distributed_${BACKEND_NAME}, allreduce_prod) TEST(distributed_${BACKEND_NAME}, allreduce_prod)
{ {
test_allreduce_common(reduction::prod); test_allreduce_common(reduction::Type::PROD);
} }
#endif #endif
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment