Commit d0a83a35 authored by Scott Cyphers, committed by Robert Kimball

Eliminate wrapped enum (#3111)

* Eliminate wrapped enum
Switch allreduce to new op form

* review comments

* review comments

* typo
parent 8a062c4a
......@@ -22,11 +22,6 @@
using namespace ngraph;
NGRAPH_API const reduction::Type reduction::sum(reduction::Type_t::sum);
NGRAPH_API const reduction::Type reduction::prod(reduction::Type_t::prod);
NGRAPH_API const reduction::Type reduction::min(reduction::Type_t::min);
NGRAPH_API const reduction::Type reduction::max(reduction::Type_t::max);
std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj)
{
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
......@@ -34,12 +29,12 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& ob
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (obj.get_type())
switch (obj)
{
case reduction::Type_t::sum: out << "sum"; break;
case reduction::Type_t::prod: out << "prod"; break;
case reduction::Type_t::min: out << "min"; break;
case reduction::Type_t::max: out << "max"; break;
case reduction::Type::SUM: out << "SUM"; break;
case reduction::Type::PROD: out << "PROD"; break;
case reduction::Type::MIN: out << "MIN"; break;
case reduction::Type::MAX: out << "MAX"; break;
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
......@@ -47,16 +42,6 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& ob
return out;
};
// Equality for the (since-removed) reduction::Type wrapper class: two wrappers
// compare equal iff they hold the same underlying Type_t enumerator.
// NOTE(review): this definition is the pre-image side of the diff; the commit
// deletes it along with the wrapper class.
bool reduction::Type::operator==(const reduction::Type& other) const
{
    return m_type == other.m_type;
}
// Accessor exposing the wrapped Type_t enumerator of the (since-removed)
// reduction::Type wrapper class.
// NOTE(review): pre-image side of the diff; removed by this commit because the
// wrapper is replaced by a plain scoped enum, making the accessor unnecessary.
reduction::Type_t reduction::Type::get_type() const
{
    return m_type;
}
static std::unique_ptr<DistributedInterface> s_distributed_interface;
void ngraph::set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface)
......
......@@ -26,34 +26,15 @@ namespace ngraph
{
namespace reduction
{
enum class Type_t
enum class Type
{
sum,
prod,
min,
max,
SUM,
PROD,
MIN,
MAX,
};
class Type
{
public:
Type(const Type_t t)
: m_type(t)
{
}
friend std::ostream& operator<<(std::ostream&, const Type&);
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
Type_t get_type() const;
private:
Type_t m_type;
};
std::ostream& operator<<(std::ostream& out, const Type& obj);
extern NGRAPH_API const Type sum;
extern NGRAPH_API const Type prod;
extern NGRAPH_API const Type min;
extern NGRAPH_API const Type max;
}
class DistributedInterface
......
......@@ -92,14 +92,14 @@ namespace ngraph
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (reduce_type.get_type())
switch (reduce_type)
{
case reduction::Type_t::sum: mlsl_reduce_type = MLSL::RT_SUM; break;
case reduction::Type_t::prod:
case reduction::Type::SUM: mlsl_reduce_type = MLSL::RT_SUM; break;
case reduction::Type::PROD:
throw std::runtime_error("MLSL doesn't support allreduce prod");
break;
case reduction::Type_t::min: mlsl_reduce_type = MLSL::RT_MIN; break;
case reduction::Type_t::max: mlsl_reduce_type = MLSL::RT_MAX; break;
case reduction::Type::MIN: mlsl_reduce_type = MLSL::RT_MIN; break;
case reduction::Type::MAX: mlsl_reduce_type = MLSL::RT_MAX; break;
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
......
......@@ -104,12 +104,12 @@ namespace ngraph
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (reduce_type.get_type())
switch (reduce_type)
{
case reduction::Type_t::sum: mpi_reduce_type = MPI_SUM; break;
case reduction::Type_t::prod: mpi_reduce_type = MPI_PROD; break;
case reduction::Type_t::min: mpi_reduce_type = MPI_MIN; break;
case reduction::Type_t::max: mpi_reduce_type = MPI_MAX; break;
case reduction::Type::SUM: mpi_reduce_type = MPI_SUM; break;
case reduction::Type::PROD: mpi_reduce_type = MPI_PROD; break;
case reduction::Type::MIN: mpi_reduce_type = MPI_MIN; break;
case reduction::Type::MAX: mpi_reduce_type = MPI_MAX; break;
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
......
......@@ -21,13 +21,8 @@ using namespace ngraph;
const string op::AllReduce::type_name{"AllReduce"};
// Old default constructor: initialized the reduction mode to the global
// `reduction::sum` wrapper constant.
// NOTE(review): pre-image side of the diff — the commit replaces this with
// `AllReduce() = default;` plus an in-class initializer
// `m_reduce_type{reduction::Type::SUM}` (see the header hunk).
op::AllReduce::AllReduce()
: m_reduce_type(reduction::sum)
{
}
op::AllReduce::AllReduce(const shared_ptr<Node>& arg, const reduction::Type reduce_type)
: Op(check_single_output_args({arg}))
op::AllReduce::AllReduce(const Output<Node>& arg, reduction::Type reduce_type)
: Op({arg})
, m_reduce_type(reduce_type)
{
constructor_validate_and_infer_types();
......@@ -56,3 +51,8 @@ reduction::Type op::AllReduce::get_reduce_type() const
{
return m_reduce_type;
}
// New setter added by this commit: lets callers change the reduction mode of
// an existing AllReduce node. Enabled by dropping `const` from the
// `m_reduce_type` member (see the header hunk in this diff).
void op::AllReduce::set_reduce_type(reduction::Type reduce_type)
{
    m_reduce_type = reduce_type;
}
......@@ -29,17 +29,17 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
AllReduce();
AllReduce(const std::shared_ptr<Node>& arg,
const reduction::Type reduce_type = reduction::sum);
AllReduce() = default;
AllReduce(const Output<Node>& arg, reduction::Type reduce_type = reduction::Type::SUM);
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
reduction::Type get_reduce_type() const;
void set_reduce_type(reduction::Type reduce_type);
private:
const reduction::Type m_reduce_type;
reduction::Type m_reduce_type{reduction::Type::SUM};
};
}
}
......@@ -50,25 +50,25 @@ static void test_allreduce_common(reduction::Type reduce_type)
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (reduce_type.get_type())
switch (reduce_type)
{
case reduction::Type_t::sum:
case reduction::Type::SUM:
copy_data(a, v);
std::transform(
v.begin(), v.end(), v.begin(), std::bind1st(std::multiplies<float>(), comm_size));
break;
case reduction::Type_t::prod:
case reduction::Type::PROD:
copy_data(a, v);
std::transform(v.begin(), v.end(), v.begin(), [&](float elm) -> float {
return pow(elm, comm_size);
});
break;
case reduction::Type_t::min:
case reduction::Type_t::max:
case reduction::Type::MIN:
case reduction::Type::MAX:
auto shift = get_distributed_interface()->get_rank();
std::rotate(v.begin(), v.begin() + shift % v.size(), v.end());
copy_data(a, v);
if (reduce_type == reduction::Type_t::min)
if (reduce_type == reduction::Type::MIN)
{
std::fill(v.begin(), v.end(), 1);
for (int i = 1; i < static_cast<int>(v.size()) - comm_size + 1; i++)
......@@ -93,23 +93,23 @@ static void test_allreduce_common(reduction::Type reduce_type)
TEST(distributed_${BACKEND_NAME}, allreduce_sum)
{
test_allreduce_common(reduction::sum);
test_allreduce_common(reduction::Type::SUM);
}
TEST(distributed_${BACKEND_NAME}, allreduce_min)
{
test_allreduce_common(reduction::min);
test_allreduce_common(reduction::Type::MIN);
}
TEST(distributed_${BACKEND_NAME}, allreduce_max)
{
test_allreduce_common(reduction::max);
test_allreduce_common(reduction::Type::MAX);
}
#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
TEST(distributed_${BACKEND_NAME}, allreduce_prod)
{
test_allreduce_common(reduction::prod);
test_allreduce_common(reduction::Type::PROD);
}
#endif
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment