Commit e76bc029 authored by fenglei's avatar fenglei

add reference implementation

parent 68d7e286
...@@ -72,10 +72,9 @@ namespace ngraph ...@@ -72,10 +72,9 @@ namespace ngraph
size_t count) = 0; size_t count) = 0;
virtual void virtual void
broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0; broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0;
virtual void recv(void* in, element::Type_t element_type, size_t count, int src_id) = 0;
virtual void virtual void
send(void* in, element::Type_t element_type, size_t count, int dest_id) = 0; send(const void* in, element::Type_t element_type, size_t count, int dest_id) = 0;
virtual void
recv(void* in, element::Type_t element_type, size_t count, int src_id) = 0;
}; };
void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface); void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface);
......
...@@ -138,18 +138,15 @@ namespace ngraph ...@@ -138,18 +138,15 @@ namespace ngraph
env.DeleteDistribution(distribution); env.DeleteDistribution(distribution);
} }
void recv(void* in, void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
element::Type_t element_type,
size_t count,
int src_id) override
{ {
throw ngraph_error("recv not supported/mentioned in MLSL"); throw ngraph_error("recv not supported/mentioned in MLSL");
} }
void send(void* in, void send(const void* in,
element::Type_t element_type, element::Type_t element_type,
size_t count, size_t count,
int dest_id) override int dest_id) override
{ {
throw ngraph_error("send not supported/mentioned in MLSL"); throw ngraph_error("send not supported/mentioned in MLSL");
} }
......
...@@ -52,21 +52,19 @@ namespace ngraph ...@@ -52,21 +52,19 @@ namespace ngraph
throw ngraph_error("Distributed Library not supported/mentioned"); throw ngraph_error("Distributed Library not supported/mentioned");
} }
void recv(void* in, void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
element::Type_t element_type,
size_t count,
int src_id) override
{ {
throw ngraph_error("Distributed Library not supported/mentioned"); throw ngraph_error("Distributed Library not supported/mentioned");
} }
void send(void* in, void send(const void* in,
element::Type_t element_type, element::Type_t element_type,
size_t count, size_t count,
int dest_id) override int dest_id) override
{ {
throw ngraph_error("Distributed Library not supported/mentioned"); throw ngraph_error("Distributed Library not supported/mentioned");
} }
protected: protected:
std::string m_name{"NULL"}; std::string m_name{"NULL"};
}; };
......
...@@ -137,10 +137,7 @@ namespace ngraph ...@@ -137,10 +137,7 @@ namespace ngraph
MPI_Bcast(in, count, data_type, root_id, MPI_COMM_WORLD); MPI_Bcast(in, count, data_type, root_id, MPI_COMM_WORLD);
} }
void recv(void* in, void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
element::Type_t element_type,
size_t count,
int src_id) override
{ {
auto data_type = MPI_FLOAT; auto data_type = MPI_FLOAT;
...@@ -150,17 +147,16 @@ namespace ngraph ...@@ -150,17 +147,16 @@ namespace ngraph
} }
else if (element_type != element::Type_t::f32) else if (element_type != element::Type_t::f32)
{ {
throw std::runtime_error( throw std::runtime_error("recv op supports only f32 and f64 types");
"recv op supports only f32 and f64 types");
} }
MPI_Recv(in, count, data_type, src_id, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); MPI_Recv(in, count, data_type, src_id, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
} }
void send(void* in, void send(const void* in,
element::Type_t element_type, element::Type_t element_type,
size_t count, size_t count,
int dest_id) override int dest_id) override
{ {
auto data_type = MPI_FLOAT; auto data_type = MPI_FLOAT;
...@@ -170,8 +166,7 @@ namespace ngraph ...@@ -170,8 +166,7 @@ namespace ngraph
} }
else if (element_type != element::Type_t::f32) else if (element_type != element::Type_t::f32)
{ {
throw std::runtime_error( throw std::runtime_error("send op supports only f32 and f64 types");
"send op supports only f32 and f64 types");
} }
MPI_Send(in, count, data_type, dest_id, 0, MPI_COMM_WORLD); MPI_Send(in, count, data_type, dest_id, 0, MPI_COMM_WORLD);
......
...@@ -55,11 +55,13 @@ ...@@ -55,11 +55,13 @@
#include "ngraph/op/passthrough.hpp" #include "ngraph/op/passthrough.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp" #include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp" #include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
...@@ -127,6 +129,7 @@ ...@@ -127,6 +129,7 @@
#include "ngraph/runtime/reference/power.hpp" #include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp" #include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp" #include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/recv.hpp"
#include "ngraph/runtime/reference/relu.hpp" #include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp" #include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
...@@ -136,6 +139,7 @@ ...@@ -136,6 +139,7 @@
#include "ngraph/runtime/reference/scatter_add.hpp" #include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp" #include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/select.hpp" #include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/send.hpp"
#include "ngraph/runtime/reference/shape_of.hpp" #include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp" #include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp" #include "ngraph/runtime/reference/sign.hpp"
...@@ -1178,6 +1182,21 @@ private: ...@@ -1178,6 +1182,21 @@ private:
throw unsupported_op("Unsupported op '" + node.description() + throw unsupported_op("Unsupported op '" + node.description() +
"' in Interpreter back end."); "' in Interpreter back end.");
} }
case OP_TYPEID::Recv:
{
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Recv*>(&node);
int src_id = op->get_src_id();
reference::recv<T>(args[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
src_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::Relu: case OP_TYPEID::Relu:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
...@@ -1324,6 +1343,21 @@ private: ...@@ -1324,6 +1343,21 @@ private:
element_count); element_count);
break; break;
} }
case OP_TYPEID::Send:
{
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Send*>(&node);
int dest_id = op->get_dest_id();
reference::send<T>(args[0]->get_data_ptr<const T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
dest_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::ShapeOf: case OP_TYPEID::ShapeOf:
{ {
reference::shape_of(node.get_input_shape(0), out[0]->get_data_ptr<uint64_t>()); reference::shape_of(node.get_input_shape(0), out[0]->get_data_ptr<uint64_t>());
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/distributed.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void recv(T* arg, const element::Type_t element_type, size_t count, int src_id)
{
get_distributed_interface()->recv(arg, element_type, count, src_id);
}
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/distributed.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void send(const T* arg, const element::Type_t element_type, size_t count, int dest_id)
{
get_distributed_interface()->send(arg, element_type, count, dest_id);
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment