Commit b18eb8cc authored by fenglei's avatar fenglei

support more data type

parent 64fe235f
...@@ -140,14 +140,14 @@ namespace ngraph ...@@ -140,14 +140,14 @@ namespace ngraph
void recv(void* in, element::Type_t element_type, size_t count, int src_id) override void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
{ {
auto data_type = MPI_FLOAT; auto data_type = MPI_FLOAT;
// for send/recv bf16 and f16 can be treat as MPI_SHORT since all are 16bits
if (element_type == element::Type_t::f64) if (element_type == element::Type_t::bf16 || element_type == element::Type_t::f16)
{ {
data_type = MPI_DOUBLE; data_type = MPI_SHORT;
} }
else if (element_type != element::Type_t::f32) else
{ {
throw std::runtime_error("recv op supports only f32 and f64 types"); data_type = ngraph_type_to_mpi_type(element_type);
} }
MPI_Recv(in, count, data_type, src_id, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); MPI_Recv(in, count, data_type, src_id, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
...@@ -159,20 +159,52 @@ namespace ngraph ...@@ -159,20 +159,52 @@ namespace ngraph
int dest_id) override int dest_id) override
{ {
auto data_type = MPI_FLOAT; auto data_type = MPI_FLOAT;
// for send/recv bf16 and f16 can be treat as MPI_SHORT since all are 16bits
if (element_type == element::Type_t::f64) if (element_type == element::Type_t::bf16 || element_type == element::Type_t::f16)
{ {
data_type = MPI_DOUBLE; data_type = MPI_SHORT;
} }
else if (element_type != element::Type_t::f32) else
{ {
throw std::runtime_error("send op supports only f32 and f64 types"); data_type = ngraph_type_to_mpi_type(element_type);
} }
MPI_Send(in, count, data_type, dest_id, 0, MPI_COMM_WORLD); MPI_Send(in, count, data_type, dest_id, 0, MPI_COMM_WORLD);
} }
protected: protected:
MPI_Datatype ngraph_type_to_mpi_type(element::Type_t& n_type)
{
MPI_Datatype m_type = MPI_FLOAT;
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (n_type)
{
case element::Type_t::boolean: m_type = MPI_BYTE; break;
case element::Type_t::f32: m_type = MPI_FLOAT; break;
case element::Type_t::f64: m_type = MPI_DOUBLE; break;
case element::Type_t::i8: m_type = MPI_BYTE; break;
case element::Type_t::i16: m_type = MPI_SHORT; break;
case element::Type_t::i32: m_type = MPI_INT; break;
case element::Type_t::i64: m_type = MPI_LONG; break;
case element::Type_t::u8: m_type = MPI_UNSIGNED_CHAR; break;
case element::Type_t::u16: m_type = MPI_UNSIGNED_SHORT; break;
case element::Type_t::u32: m_type = MPI_UNSIGNED; break;
case element::Type_t::u64: m_type = MPI_UNSIGNED_LONG; break;
case element::Type_t::bf16:
case element::Type_t::f16:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
return m_type;
}
std::string m_name; std::string m_name;
bool m_initialized_mpi = false; bool m_initialized_mpi = false;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment