Commit 2b4989ed authored by fenglei's avatar fenglei

fix bug in serializer, add MPI send recv

parent b06a4368
......@@ -72,6 +72,10 @@ namespace ngraph
size_t count) = 0;
virtual void
broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0;
virtual void
send(void* in, element::Type_t element_type, size_t count, int dest_id) = 0;
virtual void
recv(void* in, element::Type_t element_type, size_t count, int src_id) = 0;
};
void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface);
......
......@@ -52,6 +52,21 @@ namespace ngraph
throw ngraph_error("Distributed Library not supported/mentioned");
}
void recv(void* in,
element::Type_t element_type,
size_t count,
int src_id) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
void send(void* in,
element::Type_t element_type,
size_t count,
int dest_id) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
protected:
std::string m_name{"NULL"};
};
......
......@@ -137,6 +137,46 @@ namespace ngraph
MPI_Bcast(in, count, data_type, root_id, MPI_COMM_WORLD);
}
void recv(void* in,
element::Type_t element_type,
size_t count,
int src_id) override
{
auto data_type = MPI_FLOAT;
if (element_type == element::Type_t::f64)
{
data_type = MPI_DOUBLE;
}
else if (element_type != element::Type_t::f32)
{
throw std::runtime_error(
"recv op supports only f32 and f64 types");
}
MPI_Recv(in, count, data_type, src_id, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
void send(void* in,
element::Type_t element_type,
size_t count,
int dest_id) override
{
auto data_type = MPI_FLOAT;
if (element_type == element::Type_t::f64)
{
data_type = MPI_DOUBLE;
}
else if (element_type != element::Type_t::f32)
{
throw std::runtime_error(
"send op supports only f32 and f64 types");
}
MPI_Send(in, count, data_type, dest_id, 0, MPI_COMM_WORLD);
}
protected:
std::string m_name;
bool m_initialized_mpi = false;
......
......@@ -116,6 +116,7 @@
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
......@@ -125,6 +126,7 @@
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
......@@ -1559,8 +1561,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::Recv:
{
auto src_id = node_js.at("source_id").get<vector<size_t>>();
node = make_shared<op::Relu>(args[0], src_id);
auto src_id = node_js.at("source_id").get<size_t>();
node = make_shared<op::Recv>(args[0], src_id);
break;
}
case OP_TYPEID::Relu:
......@@ -1637,8 +1639,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::Send:
{
auto dest_id = node_js.at("dest_id").get<vector<size_t>>();
node = make_shared<op::Relu>(args[0], dest_id);
auto dest_id = node_js.at("dest_id").get<size_t>();
node = make_shared<op::Send>(args[0], dest_id);
break;
}
case OP_TYPEID::ShapeOf:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment