Unverified commit e8e3db24 authored by Robert Kimball, committed by GitHub

add send recv op (#3107)

* add send recv file

* update CMakeLists

* add to op_tbl, add to serializer

* fix bug in serializer, add MPI send recv

* not supported in MLSL

* GPU does not support send/recv

* add reference implementation

* send/recv not supported by Intel GPU

* resolve Scott's comment about construction

* style

* add comments, add to gpu not supported list

* using Output<Node> instead of shared_ptr<Node>

* add test

* disable test for MLSL

* disable test for cpu, gpu, gpuh

* add static string s_manifest = "${MANIFEST}";

* using NGRAPH_TEST so the test can be disabled for a specific backend

* float number

* revert last change

* default

* support more data types

* change License text

* change License text

* add ring send/recv test

* add ring send/recv test

* skip send_recv_ring

* fix bug
parents bd4a1050 b3d99f67
......@@ -250,6 +250,8 @@ set (SRC
op/product.hpp
op/quantize.cpp
op/quantize.hpp
op/recv.cpp
op/recv.hpp
op/relu.cpp
op/relu.hpp
op/replace_slice.cpp
......@@ -268,6 +270,8 @@ set (SRC
op/scatter_nd_add.hpp
op/select.cpp
op/select.hpp
op/send.cpp
op/send.hpp
op/sigmoid.cpp
op/sigmoid.hpp
op/sign.cpp
......
......@@ -53,6 +53,9 @@ namespace ngraph
size_t count) = 0;
virtual void
broadcast(void* in, element::Type_t element_type, size_t count, int root_id) = 0;
virtual void recv(void* in, element::Type_t element_type, size_t count, int src_id) = 0;
virtual void
send(const void* in, element::Type_t element_type, size_t count, int dest_id) = 0;
};
void set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface);
......
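For orientation, here is a minimal sketch of how a caller reaches the two new virtuals through the process-wide interface (a sketch only, assuming an MPI-enabled build running exactly two ranks; names match the header above):

// Sketch: exchange one f32 value from rank 0 to rank 1.
auto dist = ngraph::get_distributed_interface();
float payload = 0.0f;
if (dist->get_rank() == 0)
{
    payload = 42.0f;
    dist->send(&payload, ngraph::element::Type_t::f32, 1, /*dest_id*/ 1);
}
else
{
    dist->recv(&payload, ngraph::element::Type_t::f32, 1, /*src_id*/ 0);
}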
......@@ -138,6 +138,19 @@ namespace ngraph
env.DeleteDistribution(distribution);
}
void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
{
throw ngraph_error("recv not supported/mentioned in MLSL");
}
void send(const void* in,
element::Type_t element_type,
size_t count,
int dest_id) override
{
throw ngraph_error("send not supported/mentioned in MLSL");
}
protected:
std::string m_name{"MLSL"};
bool m_initialized_mlsl = false;
......
......@@ -52,6 +52,19 @@ namespace ngraph
throw ngraph_error("Distributed Library not supported/mentioned");
}
void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
void send(const void* in,
element::Type_t element_type,
size_t count,
int dest_id) override
{
throw ngraph_error("Distributed Library not supported/mentioned");
}
protected:
std::string m_name{"NULL"};
};
......
......@@ -137,7 +137,74 @@ namespace ngraph
MPI_Bcast(in, count, data_type, root_id, MPI_COMM_WORLD);
}
void recv(void* in, element::Type_t element_type, size_t count, int src_id) override
{
auto data_type = MPI_FLOAT;
// for send/recv, bf16 and f16 can be treated as MPI_SHORT since both are 16 bits
if (element_type == element::Type_t::bf16 || element_type == element::Type_t::f16)
{
data_type = MPI_SHORT;
}
else
{
data_type = ngraph_type_to_mpi_type(element_type);
}
MPI_Recv(in, count, data_type, src_id, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
void send(const void* in,
element::Type_t element_type,
size_t count,
int dest_id) override
{
auto data_type = MPI_FLOAT;
// for send/recv, bf16 and f16 can be treated as MPI_SHORT since both are 16 bits
if (element_type == element::Type_t::bf16 || element_type == element::Type_t::f16)
{
data_type = MPI_SHORT;
}
else
{
data_type = ngraph_type_to_mpi_type(element_type);
}
MPI_Send(in, count, data_type, dest_id, 0, MPI_COMM_WORLD);
}
protected:
MPI_Datatype ngraph_type_to_mpi_type(element::Type_t& n_type)
{
MPI_Datatype m_type = MPI_FLOAT;
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (n_type)
{
case element::Type_t::boolean: m_type = MPI_BYTE; break;
case element::Type_t::f32: m_type = MPI_FLOAT; break;
case element::Type_t::f64: m_type = MPI_DOUBLE; break;
case element::Type_t::i8: m_type = MPI_BYTE; break;
case element::Type_t::i16: m_type = MPI_SHORT; break;
case element::Type_t::i32: m_type = MPI_INT; break;
case element::Type_t::i64: m_type = MPI_LONG; break;
case element::Type_t::u8: m_type = MPI_UNSIGNED_CHAR; break;
case element::Type_t::u16: m_type = MPI_UNSIGNED_SHORT; break;
case element::Type_t::u32: m_type = MPI_UNSIGNED; break;
case element::Type_t::u64: m_type = MPI_UNSIGNED_LONG; break;
case element::Type_t::bf16:
case element::Type_t::f16:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
return m_type;
}
std::string m_name;
bool m_initialized_mpi = false;
};
......
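The bf16/f16 special case above is sound because point-to-point MPI transfers only care about byte width, not value semantics: any 16-bit payload can travel as MPI_SHORT. A standalone sketch of the same trick (a hypothetical helper, not part of this change):

#include <mpi.h>
#include <cstdint>

// Move `count` 16-bit values (e.g. bf16 bit patterns) from rank 0 to rank 1.
// MPI_SHORT is only a width label here; the bytes are never reinterpreted.
void exchange_16bit(uint16_t* buf, int count, int rank)
{
    if (rank == 0)
    {
        MPI_Send(buf, count, MPI_SHORT, 1, 0, MPI_COMM_WORLD);
    }
    else if (rank == 1)
    {
        MPI_Recv(buf, count, MPI_SHORT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    }
}

Note that this shortcut is only valid for plain transfers; it would not be safe for reductions, where MPI must interpret the values.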
......@@ -144,6 +144,7 @@
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
......@@ -152,6 +153,7 @@
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
......
......@@ -128,6 +128,7 @@ NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
NGRAPH_OP(QuantizedDot, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(QuantizedMaxPool, ngraph::op)
NGRAPH_OP(Recv, ngraph::op)
NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Relu, ngraph::op)
NGRAPH_OP(ReluBackprop, ngraph::op)
......@@ -140,6 +141,7 @@ NGRAPH_OP(ScalarConstantLike, ngraph::op)
NGRAPH_OP(ScatterAdd, ngraph::op)
NGRAPH_OP(ScatterNDAdd, ngraph::op)
NGRAPH_OP(Select, ngraph::op)
NGRAPH_OP(Send, ngraph::op)
NGRAPH_OP(ShapeOf, ngraph::op)
NGRAPH_OP(Sigmoid, ngraph::op)
NGRAPH_OP(SigmoidBackprop, ngraph::op)
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/recv.hpp"
using namespace std;
using namespace ngraph;
const string op::Recv::type_name{"Recv"};
op::Recv::Recv(const Output<Node>& arg, int src_id)
: Op({arg})
, m_src_id(src_id)
{
constructor_validate_and_infer_types();
}
void op::Recv::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64,
"Only element types f32 and f64 are supported (argument element type: ",
get_input_element_type(0),
").");
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
shared_ptr<Node> op::Recv::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Recv>(new_args.at(0), m_src_id);
}
int op::Recv::get_src_id() const
{
return m_src_id;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
class Recv : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an uninitialized Recv operation.
Recv() = default;
/// \brief Constructs a Recv operation.
///
/// \param arg The node producing the tensor into which data is received
/// \param src_id The source ID, which can be a rank or node ID.
Recv(const Output<Node>& arg, int src_id);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
int get_src_id() const;
private:
const int m_src_id;
};
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/send.hpp"
using namespace std;
using namespace ngraph;
const string op::Send::type_name{"Send"};
op::Send::Send(const Output<Node>& arg, int dest_id)
: Op({arg})
, m_dest_id(dest_id)
{
constructor_validate_and_infer_types();
}
void op::Send::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64,
"Only element types f32 and f64 are supported (argument element type: ",
get_input_element_type(0),
").");
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
shared_ptr<Node> op::Send::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Send>(new_args.at(0), m_dest_id);
}
int op::Send::get_dest_id() const
{
return m_dest_id;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
class Send : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an uninitialized Send operation.
Send() = default;
/// \brief Constructs a send operation.
///
/// \param arg The node producing the input tensor
/// \param dest_id The destination ID, which can be a rank or node ID.
Send(const Output<Node>& arg, int dest_id);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
int get_dest_id() const;
private:
const int m_dest_id;
};
}
}
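With both headers in place, a rank-dependent graph is built the same way the tests at the bottom of this change do; a condensed sketch:

// Sketch: one single-op Function per rank; rank 0 sends, rank 1 receives.
auto A = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2, 2});
std::shared_ptr<ngraph::Function> f;
if (ngraph::get_distributed_interface()->get_rank() == 0)
{
    f = std::make_shared<ngraph::Function>(std::make_shared<ngraph::op::Send>(A, /*dest_id*/ 1),
                                           ngraph::ParameterVector{A});
}
else
{
    f = std::make_shared<ngraph::Function>(std::make_shared<ngraph::op::Recv>(A, /*src_id*/ 0),
                                           ngraph::ParameterVector{A});
}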
......@@ -10,4 +10,5 @@ max_3d_to_scalar_int32
# Not implemented
erf
zero_sized_erf
send_recv
send_recv_ring
......@@ -232,7 +232,9 @@ bool runtime::gpu::GPU_Backend::is_supported(const Node& op) const
"GenerateMask",
"DynBroadcast",
"Transpose",
"Range"};
"Range",
"Recv",
"Send"};
set<string> float_only = {"MaxPoolBackprop", "AvgPoolBackprop", "MaxPool", "Dot"};
......
......@@ -104,6 +104,7 @@
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
......@@ -113,6 +114,7 @@
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
......@@ -995,6 +997,11 @@ std::string runtime::gpu::GPU_Emitter::emit_QuantizedMaxPool(EMIT_ARGS)
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_Recv(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_Range(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
......@@ -1250,6 +1257,11 @@ std::string runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS)
return emit_elementwise<ngraph::op::Select>(compiled_function, function_name, node, args, out);
}
std::string runtime::gpu::GPU_Emitter::emit_Send(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_ShapeOf(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
......
......@@ -205,3 +205,5 @@ gather_no_axis_bool
fake_quantize
fake_quantize_with_clip
fake_quantize_with_clip_across_channels
send_recv
send_recv_ring
......@@ -2,3 +2,5 @@ computation_reuse
tensorview_custom_mem
batch_norm_inference_f64
batch_norm_inference_f32
send_recv
send_recv_ring
......@@ -2086,12 +2086,14 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::QuantizedDot:
case OP_TYPEID::QuantizedDotBias:
case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::Recv:
case OP_TYPEID::Range:
case OP_TYPEID::ReplaceSlice:
case OP_TYPEID::ScalarConstantLike:
case OP_TYPEID::ScaleShift:
case OP_TYPEID::ScatterAdd:
case OP_TYPEID::ScatterNDAdd:
case OP_TYPEID::Send:
case OP_TYPEID::ShapeOf:
case OP_TYPEID::ShuffleChannels:
case OP_TYPEID::SpaceToDepth:
......
......@@ -106,6 +106,8 @@ gather_no_axis_bool
fake_quantize
fake_quantize_with_clip
fake_quantize_with_clip_across_channels
send_recv
send_recv_ring
# Not supported quant ops
model_dequantize_linear_1d_zero_scale_int8
......
......@@ -55,11 +55,13 @@
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
......@@ -127,6 +129,7 @@
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/recv.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
......@@ -136,6 +139,7 @@
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/send.hpp"
#include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp"
......@@ -1178,6 +1182,21 @@ private:
throw unsupported_op("Unsupported op '" + node.description() +
"' in Interpreter back end.");
}
case OP_TYPEID::Recv:
{
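// Receive in place into the input buffer, then copy the data through to
// the output tensor, so Recv acts as an identity op on this rank.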
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Recv*>(&node);
int src_id = op->get_src_id();
reference::recv<T>(args[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
src_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::Range:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
......@@ -1329,6 +1348,21 @@ private:
element_count);
break;
}
case OP_TYPEID::Send:
{
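// Send from the input buffer, then forward the input unchanged to the
// output tensor so downstream ops (e.g. in the ring test) can keep using it.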
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Send*>(&node);
int dest_id = op->get_dest_id();
reference::send<T>(args[0]->get_data_ptr<const T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
dest_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::ShapeOf:
{
reference::shape_of(node.get_input_shape(0), out[0]->get_data_ptr<uint64_t>());
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/distributed.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void recv(T* arg, const element::Type_t element_type, size_t count, int src_id)
{
get_distributed_interface()->recv(arg, element_type, count, src_id);
}
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/distributed.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void send(const T* arg, const element::Type_t element_type, size_t count, int dest_id)
{
get_distributed_interface()->send(arg, element_type, count, dest_id);
}
}
}
}
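Both reference kernels are thin forwarders to the active DistributedInterface, which keeps the interpreter backend-agnostic. Invoking one directly looks like this (a sketch; the call blocks until the matching send arrives):

// Sketch: receive four f32 elements from rank 0 into a local buffer.
float buf[4];
ngraph::runtime::reference::recv<float>(buf, ngraph::element::Type_t::f32, 4, /*src_id*/ 0);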
......@@ -118,6 +118,7 @@
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
......@@ -127,6 +128,7 @@
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
......@@ -1605,6 +1607,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Recv:
{
auto src_id = node_js.at("source_id").get<size_t>();
node = make_shared<op::Recv>(args[0], src_id);
break;
}
case OP_TYPEID::Range:
{
node = make_shared<op::Range>(args[0], args[1], args[2]);
......@@ -1682,6 +1690,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Select>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::Send:
{
auto dest_id = node_js.at("dest_id").get<size_t>();
node = make_shared<op::Send>(args[0], dest_id);
break;
}
case OP_TYPEID::ShapeOf:
{
node = make_shared<op::ShapeOf>(args[0]);
......@@ -2602,6 +2616,12 @@ json JSONSerializer::serialize_node(const Node& n)
node["padding_above"] = tmp->get_padding_above();
break;
}
case OP_TYPEID::Recv:
{
auto tmp = dynamic_cast<const op::Recv*>(&n);
node["source_id"] = tmp->get_src_id();
break;
}
case OP_TYPEID::Range: { break;
}
case OP_TYPEID::Relu: { break;
......@@ -2658,6 +2678,12 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Select: { break;
}
case OP_TYPEID::Send:
{
auto tmp = dynamic_cast<const op::Send*>(&n);
node["dest_id"] = tmp->get_dest_id();
break;
}
case OP_TYPEID::ShapeOf: { break;
}
case OP_TYPEID::ShuffleChannels:
......
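The new attributes ride the existing JSON round trip; a hedged sketch using the serialize/deserialize entry points from ngraph/serializer.hpp (the field names match the hunk above; the exact JSON layout of the rest of the node is omitted):

// Sketch: dest_id survives a serialize/deserialize round trip.
auto A = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2, 2});
auto f = std::make_shared<ngraph::Function>(std::make_shared<ngraph::op::Send>(A, 1),
                                            ngraph::ParameterVector{A});
std::string js = ngraph::serialize(f); // node JSON carries "dest_id": 1
auto g = ngraph::deserialize(js);      // rebuilds op::Send with get_dest_id() == 1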
......@@ -25,10 +25,13 @@
#include "ngraph/serializer.hpp"
#include "util/all_close_f.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
static void test_allreduce_common(reduction::Type reduce_type)
{
auto comm_size = get_distributed_interface()->get_size();
......@@ -91,29 +94,29 @@ static void test_allreduce_common(reduction::Type reduce_type)
}
}
TEST(distributed_${BACKEND_NAME}, allreduce_sum)
NGRAPH_TEST(${BACKEND_NAME}, allreduce_sum)
{
test_allreduce_common(reduction::Type::SUM);
}
TEST(distributed_${BACKEND_NAME}, allreduce_min)
NGRAPH_TEST(${BACKEND_NAME}, allreduce_min)
{
test_allreduce_common(reduction::Type::MIN);
}
TEST(distributed_${BACKEND_NAME}, allreduce_max)
NGRAPH_TEST(${BACKEND_NAME}, allreduce_max)
{
test_allreduce_common(reduction::Type::MAX);
}
#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
TEST(distributed_${BACKEND_NAME}, allreduce_prod)
NGRAPH_TEST(${BACKEND_NAME}, allreduce_prod)
{
test_allreduce_common(reduction::Type::PROD);
}
#endif
TEST(distributed_${BACKEND_NAME}, broadcastdistributed)
NGRAPH_TEST(${BACKEND_NAME}, broadcastdistributed)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
......@@ -140,3 +143,90 @@ TEST(distributed_${BACKEND_NAME}, broadcastdistributed)
EXPECT_EQ(v, read_vector<float>(result));
}
}
// MLSL does not support send/recv
#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
NGRAPH_TEST(${BACKEND_NAME}, send_recv)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto comm_size = get_distributed_interface()->get_size();
// this test only works with exactly 2 nodes
if (comm_size != 2)
{
return;
}
auto rank = get_distributed_interface()->get_rank();
std::shared_ptr<Function> f;
if (rank == 0)
{
f = make_shared<Function>(make_shared<op::Send>(A, 1), ParameterVector{A});
}
else
{
f = make_shared<Function>(make_shared<op::Recv>(A, 0), ParameterVector{A});
}
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto v = vector<float>{1, 2, 3, 4};
auto result = backend->create_tensor(element::f32, shape);
copy_data(result, vector<float>(4, 0));
if (rank == 0)
{
copy_data(result, v);
}
auto handle = backend->compile(f);
handle->call_with_validate({result}, {result});
EXPECT_EQ(v, read_vector<float>(result));
}
#endif
// MLSL does not support send/recv
#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
NGRAPH_TEST(${BACKEND_NAME}, send_recv_ring)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto comm_size = get_distributed_interface()->get_size();
// this test only works with at least 2 nodes
if (comm_size < 2)
{
return;
}
auto rank = get_distributed_interface()->get_rank();
std::shared_ptr<Function> f_send;
std::shared_ptr<Function> f_recv;
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto v = vector<float>{1, 2, 3, 4};
auto result = backend->create_tensor(element::f32, shape);
copy_data(result, vector<float>(4, 0));
if (rank != 0)
{
f_recv = make_shared<Function>(make_shared<op::Recv>(A, rank - 1), ParameterVector{A});
auto handle = backend->compile(f_recv);
handle->call_with_validate({result}, {result});
EXPECT_EQ(v, read_vector<float>(result));
}
else
{
copy_data(result, v);
}
f_send =
make_shared<Function>(make_shared<op::Send>(A, (rank + 1) % comm_size), ParameterVector{A});
auto handle = backend->compile(f_send);
handle->call_with_validate({result}, {result});
if (rank == 0)
{
f_recv = make_shared<Function>(make_shared<op::Recv>(A, comm_size - 1), ParameterVector{A});
auto handle = backend->compile(f_recv);
copy_data(result, vector<float>(4, 0));
handle->call_with_validate({result}, {result});
EXPECT_EQ(v, read_vector<float>(result));
}
}
#endif
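Both tests are inherently multi-process; under an MPI build they would typically be launched with something like mpirun -np 2 <unit-test binary> --gtest_filter='*send_recv*' (the exact binary name and filter are assumptions here; MLSL builds skip these tests via the guards above, and the manifest entries disable them on CPU, GPU, and hybrid backends).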