diff --git a/src/brpc/global.cpp b/src/brpc/global.cpp index cda4167a063bdd090d4f7079422a2e4331e01f60..63b9fd62a1785de7c3493afa235cb037e8a99511 100644 --- a/src/brpc/global.cpp +++ b/src/brpc/global.cpp @@ -31,6 +31,7 @@ // Load Balancers #include "brpc/policy/round_robin_load_balancer.h" +#include "brpc/policy/weighted_round_robin_load_balancer.h" #include "brpc/policy/randomized_load_balancer.h" #include "brpc/policy/locality_aware_load_balancer.h" #include "brpc/policy/consistent_hashing_load_balancer.h" @@ -106,6 +107,7 @@ struct GlobalExtensions { RemoteFileNamingService rfns; RoundRobinLoadBalancer rr_lb; + WeightedRoundRobinLoadBalancer wrr_lb; RandomizedLoadBalancer randomized_lb; LocalityAwareLoadBalancer la_lb; ConsistentHashingLoadBalancer ch_mh_lb; @@ -318,6 +320,7 @@ static void GlobalInitializeOrDieImpl() { // Load Balancers LoadBalancerExtension()->RegisterOrDie("rr", &g_ext->rr_lb); + LoadBalancerExtension()->RegisterOrDie("wrr", &g_ext->wrr_lb); LoadBalancerExtension()->RegisterOrDie("random", &g_ext->randomized_lb); LoadBalancerExtension()->RegisterOrDie("la", &g_ext->la_lb); LoadBalancerExtension()->RegisterOrDie("c_murmurhash", &g_ext->ch_mh_lb); diff --git a/src/brpc/policy/weighted_round_robin_load_balancer.cpp b/src/brpc/policy/weighted_round_robin_load_balancer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..34c5aa6e5099f50769bf891de0ee0c7db03cd624 --- /dev/null +++ b/src/brpc/policy/weighted_round_robin_load_balancer.cpp @@ -0,0 +1,244 @@ +// Copyright (c) 2018 Iqiyi, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Authors: Daojin Cai (caidaojin@qiyi.com) + +#include <algorithm> + +#include "butil/fast_rand.h" +#include "brpc/socket.h" +#include "brpc/policy/weighted_round_robin_load_balancer.h" +#include "butil/strings/string_number_conversions.h" + +namespace { + +const std::vector<uint64_t> prime_stride = { +2,3,5,11,17,29,47,71,107,137,163,251,307,379,569,683,857,1289,1543,1949,2617, +2927,3407,4391,6599,9901,14867,22303,33457,50207,75323,112997,169501,254257, +381389,572087,849083,1273637,1910471,2865727,4298629,6447943,9671923,14507903, +21761863,32642861,48964297,73446469,110169743,165254623,247881989,371822987, +557734537,836601847,1254902827,1882354259,2823531397,4235297173,6352945771, +9529418671}; + +bool IsCoprime(uint64_t num1, uint64_t num2) { + uint64_t temp; + if (num1 < num2) { + temp = num1; + num1 = num2; + num2 = temp; + } + while (true) { + temp = num1 % num2; + if (temp == 0) { + break; + } else { + num1 = num2; + num2 = temp; + } + } + return num2 == 1; +} + +// Get a reasonable stride according to weights configured of servers. +uint64_t GetStride(const uint64_t weight_sum, const size_t num) { + if (weight_sum == 1) { + return 1; + } + uint32_t average_weight = weight_sum / num; + auto iter = std::lower_bound(prime_stride.begin(), prime_stride.end(), + average_weight); + while (iter != prime_stride.end() + && !IsCoprime(weight_sum, *iter)) { + ++iter; + } + CHECK(iter != prime_stride.end()) << "Failed to get stride"; + return *iter > weight_sum ? *iter % weight_sum : *iter; +} + +} // namespace + +namespace brpc { +namespace policy { + +bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) { + if (bg.server_list.capacity() < 128) { + bg.server_list.reserve(128); + } + uint32_t weight = 0; + if (butil::StringToUint(id.tag, &weight) && + weight > 0) { + bool insert_server = + bg.server_map.emplace(id.id, bg.server_list.size()).second; + if (insert_server) { + bg.server_list.emplace_back(id.id, weight); + bg.weight_sum += weight; + return true; + } + } else { + LOG(ERROR) << "Invalid weight is set: " << id.tag; + } + return false; +} + +bool WeightedRoundRobinLoadBalancer::Remove(Servers& bg, const ServerId& id) { + auto iter = bg.server_map.find(id.id); + if (iter != bg.server_map.end()) { + const size_t index = iter->second; + bg.weight_sum -= bg.server_list[index].weight; + bg.server_list[index] = bg.server_list.back(); + bg.server_map[bg.server_list[index].id] = index; + bg.server_list.pop_back(); + bg.server_map.erase(iter); + return true; + } + return false; +} + +size_t WeightedRoundRobinLoadBalancer::BatchAdd( + Servers& bg, const std::vector<ServerId>& servers) { + size_t count = 0; + for (size_t i = 0; i < servers.size(); ++i) { + count += !!Add(bg, servers[i]); + } + return count; +} + +size_t WeightedRoundRobinLoadBalancer::BatchRemove( + Servers& bg, const std::vector<ServerId>& servers) { + size_t count = 0; + for (size_t i = 0; i < servers.size(); ++i) { + count += !!Remove(bg, servers[i]); + } + return count; +} + +bool WeightedRoundRobinLoadBalancer::AddServer(const ServerId& id) { + return _db_servers.Modify(Add, id); +} + +bool WeightedRoundRobinLoadBalancer::RemoveServer(const ServerId& id) { + return _db_servers.Modify(Remove, id); +} + +size_t WeightedRoundRobinLoadBalancer::AddServersInBatch( + const std::vector<ServerId>& servers) { + const size_t n = _db_servers.Modify(BatchAdd, servers); + LOG_IF(ERROR, n != servers.size()) + << "Fail to AddServersInBatch, expected " << servers.size() + << " actually " << n; + return n; +} + +size_t WeightedRoundRobinLoadBalancer::RemoveServersInBatch( + const std::vector<ServerId>& servers) { + const size_t n = _db_servers.Modify(BatchRemove, servers); + LOG_IF(ERROR, n != servers.size()) + << "Fail to RemoveServersInBatch, expected " << servers.size() + << " actually " << n; + return n; +} + +int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* out) { + butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s; + if (_db_servers.Read(&s) != 0) { + return ENOMEM; + } + if (s->server_list.empty()) { + return ENODATA; + } + TLS& tls = s.tls(); + if (tls.IsNeededCaculateNewStride(s->weight_sum, s->server_list.size())) { + if (tls.stride == 0) { + tls.position = butil::fast_rand_less_than(s->server_list.size()); + } + tls.stride = GetStride(s->weight_sum, s->server_list.size()); + } + // If server list changed, the position may be out of range. + tls.position %= s->server_list.size(); + // Check whether remain server was removed from server list. + if (tls.remain_server.weight > 0 && + tls.remain_server.id != s->server_list[tls.position].id) { + tls.remain_server.weight = 0; + } + for (uint64_t i = 0; i != tls.stride; ++i) { + SocketId server_id = GetServerInNextStride(s->server_list, tls); + if (!ExcludedServers::IsExcluded(in.excluded, server_id) + && Socket::Address(server_id, out->ptr) == 0 + && !(*out->ptr)->IsLogOff()) { + return 0; + } + } + return EHOSTDOWN; +} + +SocketId WeightedRoundRobinLoadBalancer::GetServerInNextStride( + const std::vector<Server>& server_list, TLS& tls) { + SocketId final_server = 0; + uint64_t stride = tls.stride; + if (tls.remain_server.weight > 0) { + final_server = tls.remain_server.id; + if (tls.remain_server.weight > stride) { + tls.remain_server.weight -= stride; + return final_server; + } else { + stride -= tls.remain_server.weight; + tls.remain_server.weight = 0; + ++tls.position; + tls.position %= server_list.size(); + } + } + while (stride > 0) { + uint32_t configured_weight = server_list[tls.position].weight; + final_server = server_list[tls.position].id; + if (configured_weight > stride) { + tls.remain_server.id = final_server; + tls.remain_server.weight = configured_weight - stride; + return final_server; + } + stride -= configured_weight; + ++tls.position; + tls.position %= server_list.size(); + } + return final_server; +} + +LoadBalancer* WeightedRoundRobinLoadBalancer::New() const { + return new (std::nothrow) WeightedRoundRobinLoadBalancer; +} + +void WeightedRoundRobinLoadBalancer::Destroy() { + delete this; +} + +void WeightedRoundRobinLoadBalancer::Describe( + std::ostream &os, const DescribeOptions& options) { + if (!options.verbose) { + os << "wrr"; + return; + } + os << "WeightedRoundRobin{"; + butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s; + if (_db_servers.Read(&s) != 0) { + os << "fail to read _db_servers"; + } else { + os << "n=" << s->server_list.size() << ':'; + for (const auto& server : s->server_list) { + os << ' ' << server.id << '(' << server.weight << ')'; + } + } + os << '}'; +} + +} // namespace policy +} // namespace brpc diff --git a/src/brpc/policy/weighted_round_robin_load_balancer.h b/src/brpc/policy/weighted_round_robin_load_balancer.h new file mode 100644 index 0000000000000000000000000000000000000000..c22f877ae537a5580d86885a1240056f6070eae0 --- /dev/null +++ b/src/brpc/policy/weighted_round_robin_load_balancer.h @@ -0,0 +1,86 @@ +// Copyright (c) 2018 Iqiyi, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Authors: Daojin Cai (caidaojin@qiyi.com) + +#ifndef BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H +#define BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H + +#include <map> +#include <vector> +#include "butil/containers/doubly_buffered_data.h" +#include "brpc/load_balancer.h" + +namespace brpc { +namespace policy { + +// This LoadBalancer selects server as the assigned weight. +// Weight is got from tag of ServerId. +class WeightedRoundRobinLoadBalancer : public LoadBalancer { +public: + bool AddServer(const ServerId& id); + bool RemoveServer(const ServerId& id); + size_t AddServersInBatch(const std::vector<ServerId>& servers); + size_t RemoveServersInBatch(const std::vector<ServerId>& servers); + int SelectServer(const SelectIn& in, SelectOut* out); + LoadBalancer* New() const; + void Destroy(); + void Describe(std::ostream&, const DescribeOptions& options); + +private: + struct Server { + Server(SocketId s_id = 0, uint32_t s_w = 0): id(s_id), weight(s_w) {} + SocketId id; + uint32_t weight; + }; + struct Servers { + // The value is configured weight for each server. + std::vector<Server> server_list; + // The value is the index of the server in "server_list". + std::map<SocketId, size_t> server_map; + uint64_t weight_sum = 0; + }; + struct TLS { + size_t position = 0; + uint64_t stride = 0; + Server remain_server; + // If server list changed, we need caculate a new stride. + bool IsNeededCaculateNewStride(const uint64_t curr_weight_sum, + const size_t curr_servers_num) { + if (curr_weight_sum != weight_sum + || curr_servers_num != servers_num) { + weight_sum = curr_weight_sum; + servers_num = curr_servers_num; + return true; + } + return false; + } + private: + uint64_t weight_sum = 0; + size_t servers_num = 0; + }; + static bool Add(Servers& bg, const ServerId& id); + static bool Remove(Servers& bg, const ServerId& id); + static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers); + static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers); + static SocketId GetServerInNextStride(const std::vector<Server>& server_list, + TLS& tls); + + butil::DoublyBufferedData<Servers, TLS> _db_servers; +}; + +} // namespace policy +} // namespace brpc + +#endif // BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H diff --git a/test/brpc_load_balancer_unittest.cpp b/test/brpc_load_balancer_unittest.cpp index b0b8882e8de7305c20c717c4f097cd28bb864266..8c61f245ba0b7ab3d511cfad233d5ad98e9da33c 100644 --- a/test/brpc_load_balancer_unittest.cpp +++ b/test/brpc_load_balancer_unittest.cpp @@ -10,7 +10,10 @@ #include "butil/gperftools_profiler.h" #include "butil/time.h" #include "butil/containers/doubly_buffered_data.h" +#include "brpc/describable.h" #include "brpc/socket.h" +#include "butil/strings/string_number_conversions.h" +#include "brpc/policy/weighted_round_robin_load_balancer.h" #include "brpc/policy/round_robin_load_balancer.h" #include "brpc/policy/randomized_load_balancer.h" #include "brpc/policy/locality_aware_load_balancer.h" @@ -231,7 +234,7 @@ class SaveRecycle : public brpc::SocketUser { }; TEST_F(LoadBalancerTest, update_while_selection) { - for (size_t round = 0; round < 4; ++round) { + for (size_t round = 0; round < 5; ++round) { brpc::LoadBalancer* lb = NULL; SelectArg sa = { NULL, NULL}; bool is_lalb = false; @@ -242,6 +245,8 @@ TEST_F(LoadBalancerTest, update_while_selection) { } else if (round == 2) { lb = new LALB; is_lalb = true; + } else if (round == 3) { + lb = new brpc::policy::WeightedRoundRobinLoadBalancer; } else { lb = new brpc::policy::ConsistentHashingLoadBalancer( ::brpc::policy::MurmurHash32); @@ -265,6 +270,9 @@ TEST_F(LoadBalancerTest, update_while_selection) { butil::EndPoint dummy; ASSERT_EQ(0, str2endpoint(addr, &dummy)); brpc::ServerId id(8888); + if (3 == round) { + id.tag = "1"; + } brpc::SocketOptions options; options.remote_side = dummy; options.user = new SaveRecycle; @@ -342,7 +350,7 @@ TEST_F(LoadBalancerTest, update_while_selection) { } TEST_F(LoadBalancerTest, fairness) { - for (size_t round = 0; round < 4; ++round) { + for (size_t round = 0; round < 6; ++round) { brpc::LoadBalancer* lb = NULL; SelectArg sa = { NULL, NULL}; if (round == 0) { @@ -351,6 +359,8 @@ TEST_F(LoadBalancerTest, fairness) { lb = new brpc::policy::RandomizedLoadBalancer; } else if (round == 2) { lb = new LALB; + } else if (3 == round || 4 == round) { + lb = new brpc::policy::WeightedRoundRobinLoadBalancer; } else { lb = new brpc::policy::ConsistentHashingLoadBalancer( brpc::policy::MurmurHash32); @@ -375,6 +385,15 @@ TEST_F(LoadBalancerTest, fairness) { butil::EndPoint dummy; ASSERT_EQ(0, str2endpoint(addr, &dummy)); brpc::ServerId id(8888); + if (3 == round) { + id.tag = "100"; + } else if (4 == round) { + if ( i % 50 == 0) { + id.tag = std::to_string(i*2 + butil::fast_rand_less_than(40) + 80); + } else { + id.tag = std::to_string(butil::fast_rand_less_than(40) + 80); + } + } brpc::SocketOptions options; options.remote_side = dummy; options.user = new SaveRecycle; @@ -418,18 +437,41 @@ TEST_F(LoadBalancerTest, fairness) { size_t count_sum = 0; size_t count_squared_sum = 0; std::cout << lb_name << ':' << '\n'; - for (size_t i = 0; i < ids.size(); ++i) { - size_t count = total_count[ids[i].id]; - ASSERT_NE(0ul, count) << "i=" << i; - std::cout << i << '=' << count << ' '; - count_sum += count; - count_squared_sum += count * count; - } - std::cout << '\n' - << ": average=" << count_sum/ids.size() - << " deviation=" << sqrt(count_squared_sum * ids.size() - count_sum * count_sum) / ids.size() << std::endl; - + if (round != 3 && round !=4) { + for (size_t i = 0; i < ids.size(); ++i) { + size_t count = total_count[ids[i].id]; + ASSERT_NE(0ul, count) << "i=" << i; + std::cout << i << '=' << count << ' '; + count_sum += count; + count_squared_sum += count * count; + } + + std::cout << '\n' + << ": average=" << count_sum/ids.size() + << " deviation=" << sqrt(count_squared_sum * ids.size() + - count_sum * count_sum) / ids.size() << std::endl; + } else { // for weighted round robin load balancer + std::cout << "configured weight: " << std::endl; + std::ostringstream os; + brpc::DescribeOptions opt; + lb->Describe(os, opt); + std::cout << os.str() << std::endl; + double scaling_count_sum = 0.0; + double scaling_count_squared_sum = 0.0; + for (size_t i = 0; i < ids.size(); ++i) { + size_t count = total_count[ids[i].id]; + ASSERT_NE(0ul, count) << "i=" << i; + std::cout << i << '=' << count << ' '; + double scaling_count = static_cast<double>(count) / std::stoi(ids[i].tag); + scaling_count_sum += scaling_count; + scaling_count_squared_sum += scaling_count * scaling_count; + } + std::cout << '\n' + << ": scaling average=" << scaling_count_sum/ids.size() + << " scaling deviation=" << sqrt(scaling_count_squared_sum * ids.size() + - scaling_count_sum * scaling_count_sum) / ids.size() << std::endl; + } for (size_t i = 0; i < ids.size(); ++i) { ASSERT_EQ(0, brpc::Socket::SetFailed(ids[i].id)); } @@ -513,4 +555,68 @@ TEST_F(LoadBalancerTest, consistent_hashing) { } } } + +TEST_F(LoadBalancerTest, weighted_round_robin) { + const char* servers[] = { + "10.92.115.19:8831", + "10.42.108.25:8832", + "10.36.150.32:8833", + "10.92.149.48:8834", + "10.42.122.201:8835", + "10.42.122.202:8836" + }; + std::string weight[] = {"3", "2", "7", "1ab", "-1", "0"}; + std::map<butil::EndPoint, int> configed_weight; + brpc::policy::WeightedRoundRobinLoadBalancer wrrlb; + + // Add server to selected list. The server with invalid weight will be skipped. + for (size_t i = 0; i < ARRAY_SIZE(servers); ++i) { + const char *addr = servers[i]; + butil::EndPoint dummy; + ASSERT_EQ(0, str2endpoint(addr, &dummy)); + brpc::ServerId id(8888); + brpc::SocketOptions options; + options.remote_side = dummy; + options.user = new SaveRecycle; + ASSERT_EQ(0, brpc::Socket::Create(options, &id.id)); + id.tag = weight[i]; + if ( i < 3 ) { + int weight_num = 0; + ASSERT_TRUE(butil::StringToInt(weight[i], &weight_num)); + configed_weight[dummy] = weight_num; + EXPECT_TRUE(wrrlb.AddServer(id)); + } else { + EXPECT_FALSE(wrrlb.AddServer(id)); + } + } + + // Select the best server according to weight configured. + // There are 3 valid servers with weight 3, 2 and 7 respectively. + // We run SelectServer for 12 times. The result number of each server seleted should be + // consistent with weight configured. + std::map<butil::EndPoint, size_t> select_result; + brpc::SocketUniquePtr ptr; + brpc::LoadBalancer::SelectIn in = { 0, false, false, 0u, NULL }; + brpc::LoadBalancer::SelectOut out(&ptr); + int total_weight = 12; + std::vector<butil::EndPoint> select_servers; + for (int i = 0; i != total_weight; ++i) { + EXPECT_EQ(0, wrrlb.SelectServer(in, &out)); + select_servers.emplace_back(ptr->remote_side()); + ++select_result[ptr->remote_side()]; + } + + for (const auto& s : select_servers) { + std::cout << "1=" << s << ", "; + } + std::cout << std::endl; + // Check whether slected result is consistent with expected. + EXPECT_EQ(3, select_result.size()); + for (const auto& result : select_result) { + std::cout << result.first << " result=" << result.second + << " configured=" << configed_weight[result.first] << std::endl; + EXPECT_EQ(result.second, configed_weight[result.first]); + } +} + } //namespace