Unverified Commit e8e3db24 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

add send recv op (#3107)

* add send recv file

* update CMakeLists

* add to op_tbl, add to serializer

* fix bug in serializer, add MPI send recv

* not supported in MLSL

* GPU does not support send recv

* add reference implementation

* send recv not supported by intel gpu

* resolve Scott's comment about construction

* style

* add comments, add to gpu not supported list

* using Output<Node> instead of shared_ptr<Node>

* add test

* disable test for MLSL

* disable test for cpu, gpu, gpuh

* add static string s_manifest = "${MANIFEST}";

* use NGRAPH_TEST so the test can be disabled for specific backends

* float number

* revert last change

* default

* support more data type

* change License text

* change License text

* add and ring send recv test

* add and ring send recv test

* skip send_recv_ring

* fix bug
parents bd4a1050 b3d99f67
...@@ -250,6 +250,8 @@ set (SRC ...@@ -250,6 +250,8 @@ set (SRC
op/product.hpp op/product.hpp
op/quantize.cpp op/quantize.cpp
op/quantize.hpp op/quantize.hpp
op/recv.cpp
op/recv.hpp
op/relu.cpp op/relu.cpp
op/relu.hpp op/relu.hpp
op/replace_slice.cpp op/replace_slice.cpp
...@@ -268,6 +270,8 @@ set (SRC ...@@ -268,6 +270,8 @@ set (SRC
op/scatter_nd_add.hpp op/scatter_nd_add.hpp
op/select.cpp op/select.cpp
op/select.hpp op/select.hpp
op/send.cpp
op/send.hpp
op/sigmoid.cpp op/sigmoid.cpp
op/sigmoid.hpp op/sigmoid.hpp
op/sign.cpp op/sign.cpp
......
...@@ -53,6 +53,9 @@ namespace ngraph ...@@ -53,6 +53,9 @@ namespace ngraph
size_t count) = 0; size_t count) = 0;
virtual void virtual void
broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0; broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0;
/// \brief Receives data from another participant into a caller-provided buffer.
/// \param in Buffer the received data is written to.
/// \param element_type Element type of the transferred data.
/// \param count Number of elements to receive.
/// \param src_id Id of the sender.
virtual void recv(void* in, element::Type_t element_type, size_t count, int src_id) = 0;
/// \brief Sends data from a caller-provided buffer to another participant.
/// \param in Buffer holding the data to send; it is not modified.
/// \param element_type Element type of the transferred data.
/// \param count Number of elements to send.
/// \param dest_id Id of the receiver.
virtual void
send(const void* in, element::Type_t element_type, size_t count, int dest_id) = 0;
}; };
void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface); void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface);
......
...@@ -138,6 +138,19 @@ namespace ngraph ...@@ -138,6 +138,19 @@ namespace ngraph
env.DeleteDistribution(distribution); env.DeleteDistribution(distribution);
} }
/// \brief Receive entry point of the MLSL distributed interface.
///
/// Always throws ngraph_error; none of the arguments are used.
void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
{
    const char* const reason = "recv not supported/mentioned in MLSL";
    throw ngraph_error(reason);
}
/// \brief Send entry point of the MLSL distributed interface.
///
/// Always throws ngraph_error; none of the arguments are used.
void send(const void* in, element::Type_t element_type, size_t count, int dest_id) override
{
    const char* const reason = "send not supported/mentioned in MLSL";
    throw ngraph_error(reason);
}
protected: protected:
std::string m_name{"MLSL"}; std::string m_name{"MLSL"};
bool m_initialized_mlsl = false; bool m_initialized_mlsl = false;
......
...@@ -52,6 +52,19 @@ namespace ngraph ...@@ -52,6 +52,19 @@ namespace ngraph
throw ngraph_error("Distributed Library not supported/mentioned"); throw ngraph_error("Distributed Library not supported/mentioned");
} }
/// \brief Receive entry point of the "NULL" distributed interface.
///
/// Always throws ngraph_error; none of the arguments are used.
void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
{
    const char* const reason = "Distributed Library not supported/mentioned";
    throw ngraph_error(reason);
}
/// \brief Send entry point of the "NULL" distributed interface.
///
/// Always throws ngraph_error; none of the arguments are used.
void send(const void* in, element::Type_t element_type, size_t count, int dest_id) override
{
    const char* const reason = "Distributed Library not supported/mentioned";
    throw ngraph_error(reason);
}
protected: protected:
std::string m_name{"NULL"}; std::string m_name{"NULL"};
}; };
......
...@@ -137,7 +137,74 @@ namespace ngraph ...@@ -137,7 +137,74 @@ namespace ngraph
MPI_Bcast(in, count, data_type, root_id, MPI_COMM_WORLD); MPI_Bcast(in, count, data_type, root_id, MPI_COMM_WORLD);
} }
void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
{
auto data_type = MPI_FLOAT;
// for send/recv bf16 and f16 can be treat as MPI_SHORT since all are 16bits
if (element_type == element::Type_t::bf16 || element_type == element::Type_t::f16)
{
data_type = MPI_SHORT;
}
else
{
data_type = ngraph_type_to_mpi_type(element_type);
}
MPI_Recv(in, count, data_type, src_id, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
void send(const void* in,
element::Type_t element_type,
size_t count,
int dest_id) override
{
auto data_type = MPI_FLOAT;
// for send/recv bf16 and f16 can be treat as MPI_SHORT since all are 16bits
if (element_type == element::Type_t::bf16 || element_type == element::Type_t::f16)
{
data_type = MPI_SHORT;
}
else
{
data_type = ngraph_type_to_mpi_type(element_type);
}
MPI_Send(in, count, data_type, dest_id, 0, MPI_COMM_WORLD);
}
protected: protected:
/// \brief Maps an nGraph element type to the MPI datatype used for transfers.
///
/// \param n_type Element type to translate (taken by value; it is only read).
/// \return The MPI datatype chosen for \p n_type.
/// \throws std::runtime_error for bf16, f16, undefined and dynamic, which have no
///         mapping in this table.
MPI_Datatype ngraph_type_to_mpi_type(element::Type_t n_type)
{
    MPI_Datatype m_type = MPI_FLOAT;
// Turn missing enumerators into hard errors so new element types cannot be forgotten here.
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
    switch (n_type)
    {
    case element::Type_t::boolean: m_type = MPI_BYTE; break;
    case element::Type_t::f32: m_type = MPI_FLOAT; break;
    case element::Type_t::f64: m_type = MPI_DOUBLE; break;
    case element::Type_t::i8: m_type = MPI_BYTE; break;
    case element::Type_t::i16: m_type = MPI_SHORT; break;
    case element::Type_t::i32: m_type = MPI_INT; break;
    // Use the fixed-width MPI types for 64-bit integers: MPI_LONG / MPI_UNSIGNED_LONG
    // follow C `long`, which is only 32 bits wide on LLP64 and ILP32 platforms.
    case element::Type_t::i64: m_type = MPI_INT64_T; break;
    case element::Type_t::u8: m_type = MPI_UNSIGNED_CHAR; break;
    case element::Type_t::u16: m_type = MPI_UNSIGNED_SHORT; break;
    case element::Type_t::u32: m_type = MPI_UNSIGNED; break;
    case element::Type_t::u64: m_type = MPI_UINT64_T; break;
    case element::Type_t::bf16:
    case element::Type_t::f16:
    case element::Type_t::undefined:
    case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
    }
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
    return m_type;
}
std::string m_name; std::string m_name;
bool m_initialized_mpi = false; bool m_initialized_mpi = false;
}; };
......
...@@ -144,6 +144,7 @@ ...@@ -144,6 +144,7 @@
#include "ngraph/op/power.hpp" #include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
...@@ -152,6 +153,7 @@ ...@@ -152,6 +153,7 @@
#include "ngraph/op/scatter_add.hpp" #include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp" #include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp" #include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp" #include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp" #include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp" #include "ngraph/op/sin.hpp"
......
...@@ -128,6 +128,7 @@ NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op) ...@@ -128,6 +128,7 @@ NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
NGRAPH_OP(QuantizedDot, ngraph::op) NGRAPH_OP(QuantizedDot, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op) NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(QuantizedMaxPool, ngraph::op) NGRAPH_OP(QuantizedMaxPool, ngraph::op)
NGRAPH_OP(Recv, ngraph::op)
NGRAPH_OP(Range, ngraph::op) NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Relu, ngraph::op) NGRAPH_OP(Relu, ngraph::op)
NGRAPH_OP(ReluBackprop, ngraph::op) NGRAPH_OP(ReluBackprop, ngraph::op)
...@@ -140,6 +141,7 @@ NGRAPH_OP(ScalarConstantLike, ngraph::op) ...@@ -140,6 +141,7 @@ NGRAPH_OP(ScalarConstantLike, ngraph::op)
NGRAPH_OP(ScatterAdd, ngraph::op) NGRAPH_OP(ScatterAdd, ngraph::op)
NGRAPH_OP(ScatterNDAdd, ngraph::op) NGRAPH_OP(ScatterNDAdd, ngraph::op)
NGRAPH_OP(Select, ngraph::op) NGRAPH_OP(Select, ngraph::op)
NGRAPH_OP(Send, ngraph::op)
NGRAPH_OP(ShapeOf, ngraph::op) NGRAPH_OP(ShapeOf, ngraph::op)
NGRAPH_OP(Sigmoid, ngraph::op) NGRAPH_OP(Sigmoid, ngraph::op)
NGRAPH_OP(SigmoidBackprop, ngraph::op) NGRAPH_OP(SigmoidBackprop, ngraph::op)
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/recv.hpp"
using namespace std;
using namespace ngraph;
const string op::Recv::type_name{"Recv"};
op::Recv::Recv(const Output<Node>& arg, int src_id)
: Op({arg})
, m_src_id(src_id)
{
constructor_validate_and_infer_types();
}
// Checks the element type of input 0 and gives output 0 the same element type and
// partial shape as input 0.
void op::Recv::validate_and_infer_types()
{
// Accepted element types: dynamic, f32 or f64; any other type fails validation.
NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64,
"Only element types f32 and f64 are supported (argument element type: ",
get_input_element_type(0),
").");
// Output 0 mirrors input 0.
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
// Clones this node onto a new argument, keeping the same sender id.
shared_ptr<Node> op::Recv::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);

    const auto& replacement_arg = new_args.at(0);
    return make_shared<Recv>(replacement_arg, m_src_id);
}
int op::Recv::get_src_id() const
{
return m_src_id;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/op.hpp"
namespace ngraph
{
    namespace op
    {
        /// \brief Op that receives data for a tensor from the peer identified by a source id.
        class Recv : public Op
        {
        public:
            NGRAPH_API
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
            /// \brief Constructs an uninitialized recv operation.
            Recv() = default;
            /// \brief Constructs a Recv operation.
            ///
            /// \param arg The node for tensor to receive data
            /// \param src_id the source id which could be rank or node id.
            Recv(const Output<Node>& arg, int src_id);

            /// \brief Validates the input element type and sets the output type/shape.
            void validate_and_infer_types() override;

            /// \brief Creates a copy of this op over \p new_args with the same source id.
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

            /// \return The source id given at construction.
            int get_src_id() const;

        private:
            // Deliberately not `const` and given a default initializer: a const member
            // without an initializer makes the defaulted default constructor above
            // implicitly deleted, so the "uninitialized" construction path could not compile.
            int m_src_id{0};
        };
    }
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/send.hpp"
using namespace std;
using namespace ngraph;
const string op::Send::type_name{"Send"};
op::Send::Send(const Output<Node>& arg, int dest_id)
: Op({arg})
, m_dest_id(dest_id)
{
constructor_validate_and_infer_types();
}
// Checks the element type of input 0 and gives output 0 the same element type and
// partial shape as input 0.
void op::Send::validate_and_infer_types()
{
// Accepted element types: dynamic, f32 or f64; any other type fails validation.
NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64,
"Only element types f32 and f64 are supported (argument element type: ",
get_input_element_type(0),
").");
// Output 0 mirrors input 0.
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
// Clones this node onto a new argument, keeping the same destination id.
shared_ptr<Node> op::Send::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);

    const auto& replacement_arg = new_args.at(0);
    return make_shared<Send>(replacement_arg, m_dest_id);
}
int op::Send::get_dest_id() const
{
return m_dest_id;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/op.hpp"
namespace ngraph
{
    namespace op
    {
        /// \brief Op that sends its input tensor to the peer identified by a destination id.
        class Send : public Op
        {
        public:
            NGRAPH_API
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
            /// \brief Constructs an uninitialized send operation.
            Send() = default;
            /// \brief Constructs a send operation.
            ///
            /// \param arg The node for input tensor
            /// \param dest_id the target id which could be rank or node id.
            Send(const Output<Node>& arg, int dest_id);

            /// \brief Validates the input element type and sets the output type/shape.
            void validate_and_infer_types() override;

            /// \brief Creates a copy of this op over \p new_args with the same destination id.
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

            /// \return The destination id given at construction.
            int get_dest_id() const;

        private:
            // Deliberately not `const` and given a default initializer: a const member
            // without an initializer makes the defaulted default constructor above
            // implicitly deleted, so the "uninitialized" construction path could not compile.
            int m_dest_id{0};
        };
    }
}
...@@ -10,4 +10,5 @@ max_3d_to_scalar_int32 ...@@ -10,4 +10,5 @@ max_3d_to_scalar_int32
# Not implemented # Not implemented
erf erf
zero_sized_erf zero_sized_erf
send_recv
send_recv_ring
...@@ -232,7 +232,9 @@ bool runtime::gpu::GPU_Backend::is_supported(const Node& op) const ...@@ -232,7 +232,9 @@ bool runtime::gpu::GPU_Backend::is_supported(const Node& op) const
"GenerateMask", "GenerateMask",
"DynBroadcast", "DynBroadcast",
"Transpose", "Transpose",
"Range"}; "Range",
"Recv",
"Send"};
set<string> float_only = {"MaxPoolBackprop", "AvgPoolBackprop", "MaxPool", "Dot"}; set<string> float_only = {"MaxPoolBackprop", "AvgPoolBackprop", "MaxPool", "Dot"};
......
...@@ -104,6 +104,7 @@ ...@@ -104,6 +104,7 @@
#include "ngraph/op/power.hpp" #include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
...@@ -113,6 +114,7 @@ ...@@ -113,6 +114,7 @@
#include "ngraph/op/scatter_add.hpp" #include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp" #include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp" #include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp" #include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp" #include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp" #include "ngraph/op/sin.hpp"
...@@ -995,6 +997,11 @@ std::string runtime::gpu::GPU_Emitter::emit_QuantizedMaxPool(EMIT_ARGS) ...@@ -995,6 +997,11 @@ std::string runtime::gpu::GPU_Emitter::emit_QuantizedMaxPool(EMIT_ARGS)
throw unsupported_op("Unsupported op '" + node->description() + "'"); throw unsupported_op("Unsupported op '" + node->description() + "'");
} }
// Emitter stub for Recv: always throws unsupported_op naming the node's description.
std::string runtime::gpu::GPU_Emitter::emit_Recv(EMIT_ARGS)
{
    const std::string op_name = node->description();
    throw unsupported_op("Unsupported op '" + op_name + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_Range(EMIT_ARGS) std::string runtime::gpu::GPU_Emitter::emit_Range(EMIT_ARGS)
{ {
throw unsupported_op("Unsupported op '" + node->description() + "'"); throw unsupported_op("Unsupported op '" + node->description() + "'");
...@@ -1250,6 +1257,11 @@ std::string runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS) ...@@ -1250,6 +1257,11 @@ std::string runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS)
return emit_elementwise<ngraph::op::Select>(compiled_function, function_name, node, args, out); return emit_elementwise<ngraph::op::Select>(compiled_function, function_name, node, args, out);
} }
// Emitter stub for Send: always throws unsupported_op naming the node's description.
std::string runtime::gpu::GPU_Emitter::emit_Send(EMIT_ARGS)
{
    const std::string op_name = node->description();
    throw unsupported_op("Unsupported op '" + op_name + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_ShapeOf(EMIT_ARGS) std::string runtime::gpu::GPU_Emitter::emit_ShapeOf(EMIT_ARGS)
{ {
throw unsupported_op("Unsupported op '" + node->description() + "'"); throw unsupported_op("Unsupported op '" + node->description() + "'");
......
...@@ -205,3 +205,5 @@ gather_no_axis_bool ...@@ -205,3 +205,5 @@ gather_no_axis_bool
fake_quantize fake_quantize
fake_quantize_with_clip fake_quantize_with_clip
fake_quantize_with_clip_across_channels fake_quantize_with_clip_across_channels
send_recv
send_recv_ring
...@@ -2,3 +2,5 @@ computation_reuse ...@@ -2,3 +2,5 @@ computation_reuse
tensorview_custom_mem tensorview_custom_mem
batch_norm_inference_f64 batch_norm_inference_f64
batch_norm_inference_f32 batch_norm_inference_f32
send_recv
send_recv_ring
...@@ -2086,12 +2086,14 @@ shared_ptr<runtime::Executable> ...@@ -2086,12 +2086,14 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::QuantizedDot: case OP_TYPEID::QuantizedDot:
case OP_TYPEID::QuantizedDotBias: case OP_TYPEID::QuantizedDotBias:
case OP_TYPEID::QuantizedMaxPool: case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::Recv:
case OP_TYPEID::Range: case OP_TYPEID::Range:
case OP_TYPEID::ReplaceSlice: case OP_TYPEID::ReplaceSlice:
case OP_TYPEID::ScalarConstantLike: case OP_TYPEID::ScalarConstantLike:
case OP_TYPEID::ScaleShift: case OP_TYPEID::ScaleShift:
case OP_TYPEID::ScatterAdd: case OP_TYPEID::ScatterAdd:
case OP_TYPEID::ScatterNDAdd: case OP_TYPEID::ScatterNDAdd:
case OP_TYPEID::Send:
case OP_TYPEID::ShapeOf: case OP_TYPEID::ShapeOf:
case OP_TYPEID::ShuffleChannels: case OP_TYPEID::ShuffleChannels:
case OP_TYPEID::SpaceToDepth: case OP_TYPEID::SpaceToDepth:
......
...@@ -106,6 +106,8 @@ gather_no_axis_bool ...@@ -106,6 +106,8 @@ gather_no_axis_bool
fake_quantize fake_quantize
fake_quantize_with_clip fake_quantize_with_clip
fake_quantize_with_clip_across_channels fake_quantize_with_clip_across_channels
send_recv
send_recv_ring
# Not supported quant ops # Not supported quant ops
model_dequantize_linear_1d_zero_scale_int8 model_dequantize_linear_1d_zero_scale_int8
......
...@@ -55,11 +55,13 @@ ...@@ -55,11 +55,13 @@
#include "ngraph/op/passthrough.hpp" #include "ngraph/op/passthrough.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp" #include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp" #include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
...@@ -127,6 +129,7 @@ ...@@ -127,6 +129,7 @@
#include "ngraph/runtime/reference/power.hpp" #include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp" #include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp" #include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/recv.hpp"
#include "ngraph/runtime/reference/relu.hpp" #include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp" #include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
...@@ -136,6 +139,7 @@ ...@@ -136,6 +139,7 @@
#include "ngraph/runtime/reference/scatter_add.hpp" #include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp" #include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/select.hpp" #include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/send.hpp"
#include "ngraph/runtime/reference/shape_of.hpp" #include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp" #include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp" #include "ngraph/runtime/reference/sign.hpp"
...@@ -1178,6 +1182,21 @@ private: ...@@ -1178,6 +1182,21 @@ private:
throw unsupported_op("Unsupported op '" + node.description() + throw unsupported_op("Unsupported op '" + node.description() +
"' in Interpreter back end."); "' in Interpreter back end.");
} }
// Recv: receive the data into the input tensor's buffer, then copy it to the output.
case OP_TYPEID::Recv:
{
// Element count comes from the output shape; memSize is the same amount in bytes.
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Recv*>(&node);
int src_id = op->get_src_id();
// The received elements are written into args[0]'s buffer.
reference::recv<T>(args[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
src_id);
// Copy the received data from args[0] to the node's output tensor.
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::Range: case OP_TYPEID::Range:
{ {
throw unsupported_op("Unsupported op '" + node.description() + "'"); throw unsupported_op("Unsupported op '" + node.description() + "'");
...@@ -1329,6 +1348,21 @@ private: ...@@ -1329,6 +1348,21 @@ private:
element_count); element_count);
break; break;
} }
// Send: transmit the input tensor's data, then copy the input to the output.
case OP_TYPEID::Send:
{
// Element count comes from the output shape; memSize is the same amount in bytes.
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Send*>(&node);
int dest_id = op->get_dest_id();
// The elements are read from args[0]'s buffer through a const pointer.
reference::send<T>(args[0]->get_data_ptr<const T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
dest_id);
// The output tensor receives a copy of the input data.
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::ShapeOf: case OP_TYPEID::ShapeOf:
{ {
reference::shape_of(node.get_input_shape(0), out[0]->get_data_ptr<uint64_t>()); reference::shape_of(node.get_input_shape(0), out[0]->get_data_ptr<uint64_t>());
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/distributed.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace reference
        {
            /// \brief Reference kernel for Recv: forwards to the active distributed interface.
            ///
            /// \param arg Buffer handed to the interface as the receive destination.
            /// \param element_type Element type of the data.
            /// \param count Number of elements to receive.
            /// \param src_id Id of the sender.
            template <typename T>
            void recv(T* arg, const element::Type_t element_type, size_t count, int src_id)
            {
                auto&& distributed = get_distributed_interface();
                distributed->recv(arg, element_type, count, src_id);
            }
        }
    }
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/distributed.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace reference
        {
            /// \brief Reference kernel for Send: forwards to the active distributed interface.
            ///
            /// \param arg Buffer holding the elements to send.
            /// \param element_type Element type of the data.
            /// \param count Number of elements to send.
            /// \param dest_id Id of the receiver.
            template <typename T>
            void send(const T* arg, const element::Type_t element_type, size_t count, int dest_id)
            {
                auto&& distributed = get_distributed_interface();
                distributed->send(arg, element_type, count, dest_id);
            }
        }
    }
}
...@@ -118,6 +118,7 @@ ...@@ -118,6 +118,7 @@
#include "ngraph/op/power.hpp" #include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
...@@ -127,6 +128,7 @@ ...@@ -127,6 +128,7 @@
#include "ngraph/op/scatter_add.hpp" #include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp" #include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp" #include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp" #include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp" #include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp" #include "ngraph/op/sin.hpp"
...@@ -1605,6 +1607,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1605,6 +1607,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break; break;
} }
case OP_TYPEID::Recv:
{
    // "source_id" is written from the op's int source id and op::Recv takes an int,
    // so read it back as int instead of size_t to avoid a sign-losing round trip.
    auto src_id = node_js.at("source_id").get<int>();
    node = make_shared<op::Recv>(args[0], src_id);
    break;
}
case OP_TYPEID::Range: case OP_TYPEID::Range:
{ {
node = make_shared<op::Range>(args[0], args[1], args[2]); node = make_shared<op::Range>(args[0], args[1], args[2]);
...@@ -1682,6 +1690,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1682,6 +1690,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Select>(args[0], args[1], args[2]); node = make_shared<op::Select>(args[0], args[1], args[2]);
break; break;
} }
case OP_TYPEID::Send:
{
    // "dest_id" is written from the op's int destination id and op::Send takes an int,
    // so read it back as int instead of size_t to avoid a sign-losing round trip.
    auto dest_id = node_js.at("dest_id").get<int>();
    node = make_shared<op::Send>(args[0], dest_id);
    break;
}
case OP_TYPEID::ShapeOf: case OP_TYPEID::ShapeOf:
{ {
node = make_shared<op::ShapeOf>(args[0]); node = make_shared<op::ShapeOf>(args[0]);
...@@ -2602,6 +2616,12 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -2602,6 +2616,12 @@ json JSONSerializer::serialize_node(const Node& n)
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
break; break;
} }
case OP_TYPEID::Recv:
{
    // Store the sender id under "source_id".
    const auto* recv_op = dynamic_cast<const op::Recv*>(&n);
    node["source_id"] = recv_op->get_src_id();
    break;
}
case OP_TYPEID::Range: { break; case OP_TYPEID::Range: { break;
} }
case OP_TYPEID::Relu: { break; case OP_TYPEID::Relu: { break;
...@@ -2658,6 +2678,12 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -2658,6 +2678,12 @@ json JSONSerializer::serialize_node(const Node& n)
} }
case OP_TYPEID::Select: { break; case OP_TYPEID::Select: { break;
} }
case OP_TYPEID::Send:
{
    // Store the receiver id under "dest_id".
    const auto* send_op = dynamic_cast<const op::Send*>(&n);
    node["dest_id"] = send_op->get_dest_id();
    break;
}
case OP_TYPEID::ShapeOf: { break; case OP_TYPEID::ShapeOf: { break;
} }
case OP_TYPEID::ShuffleChannels: case OP_TYPEID::ShuffleChannels:
......
...@@ -25,10 +25,13 @@ ...@@ -25,10 +25,13 @@
#include "ngraph/serializer.hpp" #include "ngraph/serializer.hpp"
#include "util/all_close_f.hpp" #include "util/all_close_f.hpp"
#include "util/random.hpp" #include "util/random.hpp"
#include "util/test_control.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
static string s_manifest = "${MANIFEST}";
static void test_allreduce_common(reduction::Type reduce_type) static void test_allreduce_common(reduction::Type reduce_type)
{ {
auto comm_size = get_distributed_interface()->get_size(); auto comm_size = get_distributed_interface()->get_size();
...@@ -91,29 +94,29 @@ static void test_allreduce_common(reduction::Type reduce_type) ...@@ -91,29 +94,29 @@ static void test_allreduce_common(reduction::Type reduce_type)
} }
} }
TEST(distributed_${BACKEND_NAME}, allreduce_sum) NGRAPH_TEST(${BACKEND_NAME}, allreduce_sum)
{ {
test_allreduce_common(reduction::Type::SUM); test_allreduce_common(reduction::Type::SUM);
} }
TEST(distributed_${BACKEND_NAME}, allreduce_min) NGRAPH_TEST(${BACKEND_NAME}, allreduce_min)
{ {
test_allreduce_common(reduction::Type::MIN); test_allreduce_common(reduction::Type::MIN);
} }
TEST(distributed_${BACKEND_NAME}, allreduce_max) NGRAPH_TEST(${BACKEND_NAME}, allreduce_max)
{ {
test_allreduce_common(reduction::Type::MAX); test_allreduce_common(reduction::Type::MAX);
} }
#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE) #if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
TEST(distributed_${BACKEND_NAME}, allreduce_prod) NGRAPH_TEST(${BACKEND_NAME}, allreduce_prod)
{ {
test_allreduce_common(reduction::Type::PROD); test_allreduce_common(reduction::Type::PROD);
} }
#endif #endif
TEST(distributed_${BACKEND_NAME}, broadcastdistributed) NGRAPH_TEST(${BACKEND_NAME}, broadcastdistributed)
{ {
auto shape = Shape{2, 2}; auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape); auto A = make_shared<op::Parameter>(element::f32, shape);
...@@ -140,3 +143,90 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed) ...@@ -140,3 +143,90 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed)
EXPECT_EQ(v, read_vector<float>(result)); EXPECT_EQ(v, read_vector<float>(result));
} }
} }
//MLSL does not support send recv
#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
// Two-process check: rank 0 sends a 2x2 f32 tensor to rank 1, and both ranks expect
// to end up holding the same values.
NGRAPH_TEST(${BACKEND_NAME}, send_recv)
{
    const Shape shape{2, 2};
    auto A = make_shared<op::Parameter>(element::f32, shape);

    // this test only works for 2 nodes
    if (get_distributed_interface()->get_size() != 2)
    {
        return;
    }

    const auto rank = get_distributed_interface()->get_rank();
    const bool is_sender = (rank == 0);

    // Rank 0 sends to rank 1; the other rank receives from rank 0.
    shared_ptr<Node> transfer;
    if (is_sender)
    {
        transfer = make_shared<op::Send>(A, 1);
    }
    else
    {
        transfer = make_shared<op::Recv>(A, 0);
    }
    shared_ptr<Function> f = make_shared<Function>(transfer, ParameterVector{A});

    auto backend = runtime::Backend::create("${BACKEND_NAME}");

    const vector<float> v{1, 2, 3, 4};
    auto result = backend->create_tensor(element::f32, shape);
    // Start from zeros everywhere; only the sender is seeded with the payload.
    copy_data(result, vector<float>(4, 0));
    if (is_sender)
    {
        copy_data(result, v);
    }

    // The same tensor is passed as both the output and the input of the call.
    auto handle = backend->compile(f);
    handle->call_with_validate({result}, {result});
    EXPECT_EQ(v, read_vector<float>(result));
}
#endif
//MLSL does not support send recv
#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
// Ring test: the payload starts on rank 0, each rank forwards it to (rank + 1) % comm_size,
// and rank 0 finally receives it back from the last rank.
NGRAPH_TEST(${BACKEND_NAME}, send_recv_ring)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto comm_size = get_distributed_interface()->get_size();
// test only works for at least 2 nodes
if (comm_size < 2)
{
return;
}
auto rank = get_distributed_interface()->get_rank();
std::shared_ptr<Function> f_send;
std::shared_ptr<Function> f_recv;
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto v = vector<float>{1, 2, 3, 4};
auto result = backend->create_tensor(element::f32, shape);
// Every rank starts with zeros so a successful receive is observable.
copy_data(result, vector<float>(4, 0));
if (rank != 0)
{
// Non-zero ranks first receive from their predecessor in the ring.
f_recv = make_shared<Function>(make_shared<op::Recv>(A, rank - 1), ParameterVector{A});
auto handle = backend->compile(f_recv);
handle->call_with_validate({result}, {result});
EXPECT_EQ(v, read_vector<float>(result));
}
else
{
// Rank 0 originates the payload instead of receiving it.
copy_data(result, v);
}
// Every rank then sends to its successor; the last rank wraps around to rank 0.
f_send =
make_shared<Function>(make_shared<op::Send>(A, (rank + 1) % comm_size), ParameterVector{A});
auto handle = backend->compile(f_send);
handle->call_with_validate({result}, {result});
if (rank == 0)
{
// Rank 0 closes the ring: clear the tensor, then receive from the last rank.
f_recv = make_shared<Function>(make_shared<op::Recv>(A, comm_size - 1), ParameterVector{A});
auto handle = backend->compile(f_recv);
copy_data(result, vector<float>(4, 0));
handle->call_with_validate({result}, {result});
EXPECT_EQ(v, read_vector<float>(result));
}
}
#endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment