Commit e76bc029 authored by fenglei's avatar fenglei

add reference implementation

parent 68d7e286
......@@ -72,10 +72,9 @@ namespace ngraph
size_t count) = 0;
virtual void
broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0;
virtual void recv(void* in, element::Type_t element_type, size_t count, int src_id) = 0;
virtual void
send(void* in, element::Type_t element_type, size_t count, int dest_id) = 0;
virtual void
recv(void* in, element::Type_t element_type, size_t count, int src_id) = 0;
send(const void* in, element::Type_t element_type, size_t count, int dest_id) = 0;
};
void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface);
......
......@@ -138,15 +138,12 @@ namespace ngraph
env.DeleteDistribution(distribution);
}
void recv(void* in,
element::Type_t element_type,
size_t count,
int src_id) override
void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
{
throw ngraph_error("recv not supported/mentioned in MLSL");
}
void send(void* in,
void send(const void* in,
element::Type_t element_type,
size_t count,
int dest_id) override
......
......@@ -52,21 +52,19 @@ namespace ngraph
throw ngraph_error("Distributed Library not supported/mentioned");
}
void recv(void* in,
element::Type_t element_type,
size_t count,
int src_id) override
void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
void send(void* in,
void send(const void* in,
element::Type_t element_type,
size_t count,
int dest_id) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
protected:
std::string m_name{"NULL"};
};
......
......@@ -137,10 +137,7 @@ namespace ngraph
MPI_Bcast(in, count, data_type, root_id, MPI_COMM_WORLD);
}
void recv(void* in,
element::Type_t element_type,
size_t count,
int src_id) override
void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
{
auto data_type = MPI_FLOAT;
......@@ -150,14 +147,13 @@ namespace ngraph
}
else if (element_type != element::Type_t::f32)
{
throw std::runtime_error(
"recv op supports only f32 and f64 types");
throw std::runtime_error("recv op supports only f32 and f64 types");
}
MPI_Recv(in, count, data_type, src_id, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
void send(void* in,
void send(const void* in,
element::Type_t element_type,
size_t count,
int dest_id) override
......@@ -170,8 +166,7 @@ namespace ngraph
}
else if (element_type != element::Type_t::f32)
{
throw std::runtime_error(
"send op supports only f32 and f64 types");
throw std::runtime_error("send op supports only f32 and f64 types");
}
MPI_Send(in, count, data_type, dest_id, 0, MPI_COMM_WORLD);
......
......@@ -55,11 +55,13 @@
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
......@@ -127,6 +129,7 @@
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/recv.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
......@@ -136,6 +139,7 @@
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/send.hpp"
#include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp"
......@@ -1178,6 +1182,21 @@ private:
throw unsupported_op("Unsupported op '" + node.description() +
"' in Interpreter back end.");
}
case OP_TYPEID::Recv:
{
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Recv*>(&node);
int src_id = op->get_src_id();
reference::recv<T>(args[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
src_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::Relu:
{
size_t element_count = shape_size(node.get_output_shape(0));
......@@ -1324,6 +1343,21 @@ private:
element_count);
break;
}
case OP_TYPEID::Send:
{
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Send*>(&node);
int dest_id = op->get_dest_id();
reference::send<T>(args[0]->get_data_ptr<const T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
dest_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::ShapeOf:
{
reference::shape_of(node.get_input_shape(0), out[0]->get_data_ptr<uint64_t>());
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/distributed.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void recv(T* arg, const element::Type_t element_type, size_t count, int src_id)
{
get_distributed_interface()->recv(arg, element_type, count, src_id);
}
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/distributed.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void send(const T* arg, const element::Type_t element_type, size_t count, int dest_id)
{
get_distributed_interface()->send(arg, element_type, count, dest_id);
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment