Commit 2b4989ed authored by fenglei's avatar fenglei

fix bug in serializer, add MPI send recv

parent b06a4368
...@@ -72,6 +72,10 @@ namespace ngraph ...@@ -72,6 +72,10 @@ namespace ngraph
size_t count) = 0; size_t count) = 0;
virtual void virtual void
broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0; broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0;
virtual void
send(void* in, element::Type_t element_type, size_t count, int dest_id) = 0;
virtual void
recv(void* in, element::Type_t element_type, size_t count, int src_id) = 0;
}; };
void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface); void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface);
......
...@@ -52,6 +52,21 @@ namespace ngraph ...@@ -52,6 +52,21 @@ namespace ngraph
throw ngraph_error("Distributed Library not supported/mentioned"); throw ngraph_error("Distributed Library not supported/mentioned");
} }
void recv(void* in,
element::Type_t element_type,
size_t count,
int src_id) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
void send(void* in,
element::Type_t element_type,
size_t count,
int dest_id) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
protected: protected:
std::string m_name{"NULL"}; std::string m_name{"NULL"};
}; };
......
...@@ -137,6 +137,46 @@ namespace ngraph ...@@ -137,6 +137,46 @@ namespace ngraph
MPI_Bcast(in, count, data_type, root_id, MPI_COMM_WORLD); MPI_Bcast(in, count, data_type, root_id, MPI_COMM_WORLD);
} }
void recv(void* in,
element::Type_t element_type,
size_t count,
int src_id) override
{
auto data_type = MPI_FLOAT;
if (element_type == element::Type_t::f64)
{
data_type = MPI_DOUBLE;
}
else if (element_type != element::Type_t::f32)
{
throw std::runtime_error(
"recv op supports only f32 and f64 types");
}
MPI_Recv(in, count, data_type, src_id, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
void send(void* in,
element::Type_t element_type,
size_t count,
int dest_id) override
{
auto data_type = MPI_FLOAT;
if (element_type == element::Type_t::f64)
{
data_type = MPI_DOUBLE;
}
else if (element_type != element::Type_t::f32)
{
throw std::runtime_error(
"send op supports only f32 and f64 types");
}
MPI_Send(in, count, data_type, dest_id, 0, MPI_COMM_WORLD);
}
protected: protected:
std::string m_name; std::string m_name;
bool m_initialized_mpi = false; bool m_initialized_mpi = false;
......
...@@ -116,6 +116,7 @@ ...@@ -116,6 +116,7 @@
#include "ngraph/op/power.hpp" #include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
...@@ -125,6 +126,7 @@ ...@@ -125,6 +126,7 @@
#include "ngraph/op/scatter_add.hpp" #include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp" #include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp" #include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp" #include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp" #include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp" #include "ngraph/op/sin.hpp"
...@@ -1559,8 +1561,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -1559,8 +1561,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
} }
case OP_TYPEID::Recv: case OP_TYPEID::Recv:
{ {
auto src_id = node_js.at("source_id").get<vector<size_t>>(); auto src_id = node_js.at("source_id").get<size_t>();
node = make_shared<op::Relu>(args[0], src_id); node = make_shared<op::Recv>(args[0], src_id);
break; break;
} }
case OP_TYPEID::Relu: case OP_TYPEID::Relu:
...@@ -1637,8 +1639,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -1637,8 +1639,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
} }
case OP_TYPEID::Send: case OP_TYPEID::Send:
{ {
auto dest_id = node_js.at("dest_id").get<vector<size_t>>(); auto dest_id = node_js.at("dest_id").get<size_t>();
node = make_shared<op::Relu>(args[0], dest_id); node = make_shared<op::Send>(args[0], dest_id);
break; break;
} }
case OP_TYPEID::ShapeOf: case OP_TYPEID::ShapeOf:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment