Unverified Commit 0c90dbc0 authored by Ge Jun's avatar Ge Jun Committed by GitHub

Merge pull request #243 from cdjingit/master

weighted round robin load balancer
parents 98755bbb 675197a0
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
// Load Balancers // Load Balancers
#include "brpc/policy/round_robin_load_balancer.h" #include "brpc/policy/round_robin_load_balancer.h"
#include "brpc/policy/weighted_round_robin_load_balancer.h"
#include "brpc/policy/randomized_load_balancer.h" #include "brpc/policy/randomized_load_balancer.h"
#include "brpc/policy/locality_aware_load_balancer.h" #include "brpc/policy/locality_aware_load_balancer.h"
#include "brpc/policy/consistent_hashing_load_balancer.h" #include "brpc/policy/consistent_hashing_load_balancer.h"
...@@ -106,6 +107,7 @@ struct GlobalExtensions { ...@@ -106,6 +107,7 @@ struct GlobalExtensions {
RemoteFileNamingService rfns; RemoteFileNamingService rfns;
RoundRobinLoadBalancer rr_lb; RoundRobinLoadBalancer rr_lb;
WeightedRoundRobinLoadBalancer wrr_lb;
RandomizedLoadBalancer randomized_lb; RandomizedLoadBalancer randomized_lb;
LocalityAwareLoadBalancer la_lb; LocalityAwareLoadBalancer la_lb;
ConsistentHashingLoadBalancer ch_mh_lb; ConsistentHashingLoadBalancer ch_mh_lb;
...@@ -318,6 +320,7 @@ static void GlobalInitializeOrDieImpl() { ...@@ -318,6 +320,7 @@ static void GlobalInitializeOrDieImpl() {
// Load Balancers // Load Balancers
LoadBalancerExtension()->RegisterOrDie("rr", &g_ext->rr_lb); LoadBalancerExtension()->RegisterOrDie("rr", &g_ext->rr_lb);
LoadBalancerExtension()->RegisterOrDie("wrr", &g_ext->wrr_lb);
LoadBalancerExtension()->RegisterOrDie("random", &g_ext->randomized_lb); LoadBalancerExtension()->RegisterOrDie("random", &g_ext->randomized_lb);
LoadBalancerExtension()->RegisterOrDie("la", &g_ext->la_lb); LoadBalancerExtension()->RegisterOrDie("la", &g_ext->la_lb);
LoadBalancerExtension()->RegisterOrDie("c_murmurhash", &g_ext->ch_mh_lb); LoadBalancerExtension()->RegisterOrDie("c_murmurhash", &g_ext->ch_mh_lb);
......
// Copyright (c) 2018 Iqiyi, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Authors: Daojin Cai (caidaojin@qiyi.com)
#include <algorithm>
#include "butil/fast_rand.h"
#include "brpc/socket.h"
#include "brpc/policy/weighted_round_robin_load_balancer.h"
#include "butil/strings/string_number_conversions.h"
namespace {
const std::vector<uint64_t> prime_stride = {
2,3,5,11,17,29,47,71,107,137,163,251,307,379,569,683,857,1289,1543,1949,2617,
2927,3407,4391,6599,9901,14867,22303,33457,50207,75323,112997,169501,254257,
381389,572087,849083,1273637,1910471,2865727,4298629,6447943,9671923,14507903,
21761863,32642861,48964297,73446469,110169743,165254623,247881989,371822987,
557734537,836601847,1254902827,1882354259,2823531397,4235297173,6352945771,
9529418671};
bool IsCoprime(uint64_t num1, uint64_t num2) {
uint64_t temp;
if (num1 < num2) {
temp = num1;
num1 = num2;
num2 = temp;
}
while (true) {
temp = num1 % num2;
if (temp == 0) {
break;
} else {
num1 = num2;
num2 = temp;
}
}
return num2 == 1;
}
// Get a reasonable stride according to weights configured of servers.
uint64_t GetStride(const uint64_t weight_sum, const size_t num) {
if (weight_sum == 1) {
return 1;
}
uint32_t average_weight = weight_sum / num;
auto iter = std::lower_bound(prime_stride.begin(), prime_stride.end(),
average_weight);
while (iter != prime_stride.end()
&& !IsCoprime(weight_sum, *iter)) {
++iter;
}
CHECK(iter != prime_stride.end()) << "Failed to get stride";
return *iter > weight_sum ? *iter % weight_sum : *iter;
}
} // namespace
namespace brpc {
namespace policy {
bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) {
if (bg.server_list.capacity() < 128) {
bg.server_list.reserve(128);
}
uint32_t weight = 0;
if (butil::StringToUint(id.tag, &weight) &&
weight > 0) {
bool insert_server =
bg.server_map.emplace(id.id, bg.server_list.size()).second;
if (insert_server) {
bg.server_list.emplace_back(id.id, weight);
bg.weight_sum += weight;
return true;
}
} else {
LOG(ERROR) << "Invalid weight is set: " << id.tag;
}
return false;
}
bool WeightedRoundRobinLoadBalancer::Remove(Servers& bg, const ServerId& id) {
auto iter = bg.server_map.find(id.id);
if (iter != bg.server_map.end()) {
const size_t index = iter->second;
bg.weight_sum -= bg.server_list[index].weight;
bg.server_list[index] = bg.server_list.back();
bg.server_map[bg.server_list[index].id] = index;
bg.server_list.pop_back();
bg.server_map.erase(iter);
return true;
}
return false;
}
size_t WeightedRoundRobinLoadBalancer::BatchAdd(
Servers& bg, const std::vector<ServerId>& servers) {
size_t count = 0;
for (size_t i = 0; i < servers.size(); ++i) {
count += !!Add(bg, servers[i]);
}
return count;
}
size_t WeightedRoundRobinLoadBalancer::BatchRemove(
Servers& bg, const std::vector<ServerId>& servers) {
size_t count = 0;
for (size_t i = 0; i < servers.size(); ++i) {
count += !!Remove(bg, servers[i]);
}
return count;
}
bool WeightedRoundRobinLoadBalancer::AddServer(const ServerId& id) {
return _db_servers.Modify(Add, id);
}
bool WeightedRoundRobinLoadBalancer::RemoveServer(const ServerId& id) {
return _db_servers.Modify(Remove, id);
}
size_t WeightedRoundRobinLoadBalancer::AddServersInBatch(
const std::vector<ServerId>& servers) {
const size_t n = _db_servers.Modify(BatchAdd, servers);
LOG_IF(ERROR, n != servers.size())
<< "Fail to AddServersInBatch, expected " << servers.size()
<< " actually " << n;
return n;
}
size_t WeightedRoundRobinLoadBalancer::RemoveServersInBatch(
const std::vector<ServerId>& servers) {
const size_t n = _db_servers.Modify(BatchRemove, servers);
LOG_IF(ERROR, n != servers.size())
<< "Fail to RemoveServersInBatch, expected " << servers.size()
<< " actually " << n;
return n;
}
int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* out) {
butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s;
if (_db_servers.Read(&s) != 0) {
return ENOMEM;
}
if (s->server_list.empty()) {
return ENODATA;
}
TLS& tls = s.tls();
if (tls.IsNeededCaculateNewStride(s->weight_sum, s->server_list.size())) {
if (tls.stride == 0) {
tls.position = butil::fast_rand_less_than(s->server_list.size());
}
tls.stride = GetStride(s->weight_sum, s->server_list.size());
}
// If server list changed, the position may be out of range.
tls.position %= s->server_list.size();
// Check whether remain server was removed from server list.
if (tls.remain_server.weight > 0 &&
tls.remain_server.id != s->server_list[tls.position].id) {
tls.remain_server.weight = 0;
}
for (uint64_t i = 0; i != tls.stride; ++i) {
SocketId server_id = GetServerInNextStride(s->server_list, tls);
if (!ExcludedServers::IsExcluded(in.excluded, server_id)
&& Socket::Address(server_id, out->ptr) == 0
&& !(*out->ptr)->IsLogOff()) {
return 0;
}
}
return EHOSTDOWN;
}
SocketId WeightedRoundRobinLoadBalancer::GetServerInNextStride(
const std::vector<Server>& server_list, TLS& tls) {
SocketId final_server = 0;
uint64_t stride = tls.stride;
if (tls.remain_server.weight > 0) {
final_server = tls.remain_server.id;
if (tls.remain_server.weight > stride) {
tls.remain_server.weight -= stride;
return final_server;
} else {
stride -= tls.remain_server.weight;
tls.remain_server.weight = 0;
++tls.position;
tls.position %= server_list.size();
}
}
while (stride > 0) {
uint32_t configured_weight = server_list[tls.position].weight;
final_server = server_list[tls.position].id;
if (configured_weight > stride) {
tls.remain_server.id = final_server;
tls.remain_server.weight = configured_weight - stride;
return final_server;
}
stride -= configured_weight;
++tls.position;
tls.position %= server_list.size();
}
return final_server;
}
LoadBalancer* WeightedRoundRobinLoadBalancer::New() const {
return new (std::nothrow) WeightedRoundRobinLoadBalancer;
}
void WeightedRoundRobinLoadBalancer::Destroy() {
delete this;
}
void WeightedRoundRobinLoadBalancer::Describe(
std::ostream &os, const DescribeOptions& options) {
if (!options.verbose) {
os << "wrr";
return;
}
os << "WeightedRoundRobin{";
butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s;
if (_db_servers.Read(&s) != 0) {
os << "fail to read _db_servers";
} else {
os << "n=" << s->server_list.size() << ':';
for (const auto& server : s->server_list) {
os << ' ' << server.id << '(' << server.weight << ')';
}
}
os << '}';
}
} // namespace policy
} // namespace brpc
// Copyright (c) 2018 Iqiyi, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Authors: Daojin Cai (caidaojin@qiyi.com)
#ifndef BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H
#define BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H
#include <map>
#include <vector>
#include "butil/containers/doubly_buffered_data.h"
#include "brpc/load_balancer.h"
namespace brpc {
namespace policy {
// This LoadBalancer selects server as the assigned weight.
// Weight is got from tag of ServerId.
class WeightedRoundRobinLoadBalancer : public LoadBalancer {
public:
bool AddServer(const ServerId& id);
bool RemoveServer(const ServerId& id);
size_t AddServersInBatch(const std::vector<ServerId>& servers);
size_t RemoveServersInBatch(const std::vector<ServerId>& servers);
int SelectServer(const SelectIn& in, SelectOut* out);
LoadBalancer* New() const;
void Destroy();
void Describe(std::ostream&, const DescribeOptions& options);
private:
struct Server {
Server(SocketId s_id = 0, uint32_t s_w = 0): id(s_id), weight(s_w) {}
SocketId id;
uint32_t weight;
};
struct Servers {
// The value is configured weight for each server.
std::vector<Server> server_list;
// The value is the index of the server in "server_list".
std::map<SocketId, size_t> server_map;
uint64_t weight_sum = 0;
};
struct TLS {
size_t position = 0;
uint64_t stride = 0;
Server remain_server;
// If server list changed, we need caculate a new stride.
bool IsNeededCaculateNewStride(const uint64_t curr_weight_sum,
const size_t curr_servers_num) {
if (curr_weight_sum != weight_sum
|| curr_servers_num != servers_num) {
weight_sum = curr_weight_sum;
servers_num = curr_servers_num;
return true;
}
return false;
}
private:
uint64_t weight_sum = 0;
size_t servers_num = 0;
};
static bool Add(Servers& bg, const ServerId& id);
static bool Remove(Servers& bg, const ServerId& id);
static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers);
static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers);
static SocketId GetServerInNextStride(const std::vector<Server>& server_list,
TLS& tls);
butil::DoublyBufferedData<Servers, TLS> _db_servers;
};
} // namespace policy
} // namespace brpc
#endif // BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H
...@@ -10,7 +10,10 @@ ...@@ -10,7 +10,10 @@
#include "butil/gperftools_profiler.h" #include "butil/gperftools_profiler.h"
#include "butil/time.h" #include "butil/time.h"
#include "butil/containers/doubly_buffered_data.h" #include "butil/containers/doubly_buffered_data.h"
#include "brpc/describable.h"
#include "brpc/socket.h" #include "brpc/socket.h"
#include "butil/strings/string_number_conversions.h"
#include "brpc/policy/weighted_round_robin_load_balancer.h"
#include "brpc/policy/round_robin_load_balancer.h" #include "brpc/policy/round_robin_load_balancer.h"
#include "brpc/policy/randomized_load_balancer.h" #include "brpc/policy/randomized_load_balancer.h"
#include "brpc/policy/locality_aware_load_balancer.h" #include "brpc/policy/locality_aware_load_balancer.h"
...@@ -231,7 +234,7 @@ class SaveRecycle : public brpc::SocketUser { ...@@ -231,7 +234,7 @@ class SaveRecycle : public brpc::SocketUser {
}; };
TEST_F(LoadBalancerTest, update_while_selection) { TEST_F(LoadBalancerTest, update_while_selection) {
for (size_t round = 0; round < 4; ++round) { for (size_t round = 0; round < 5; ++round) {
brpc::LoadBalancer* lb = NULL; brpc::LoadBalancer* lb = NULL;
SelectArg sa = { NULL, NULL}; SelectArg sa = { NULL, NULL};
bool is_lalb = false; bool is_lalb = false;
...@@ -242,6 +245,8 @@ TEST_F(LoadBalancerTest, update_while_selection) { ...@@ -242,6 +245,8 @@ TEST_F(LoadBalancerTest, update_while_selection) {
} else if (round == 2) { } else if (round == 2) {
lb = new LALB; lb = new LALB;
is_lalb = true; is_lalb = true;
} else if (round == 3) {
lb = new brpc::policy::WeightedRoundRobinLoadBalancer;
} else { } else {
lb = new brpc::policy::ConsistentHashingLoadBalancer( lb = new brpc::policy::ConsistentHashingLoadBalancer(
::brpc::policy::MurmurHash32); ::brpc::policy::MurmurHash32);
...@@ -265,6 +270,9 @@ TEST_F(LoadBalancerTest, update_while_selection) { ...@@ -265,6 +270,9 @@ TEST_F(LoadBalancerTest, update_while_selection) {
butil::EndPoint dummy; butil::EndPoint dummy;
ASSERT_EQ(0, str2endpoint(addr, &dummy)); ASSERT_EQ(0, str2endpoint(addr, &dummy));
brpc::ServerId id(8888); brpc::ServerId id(8888);
if (3 == round) {
id.tag = "1";
}
brpc::SocketOptions options; brpc::SocketOptions options;
options.remote_side = dummy; options.remote_side = dummy;
options.user = new SaveRecycle; options.user = new SaveRecycle;
...@@ -342,7 +350,7 @@ TEST_F(LoadBalancerTest, update_while_selection) { ...@@ -342,7 +350,7 @@ TEST_F(LoadBalancerTest, update_while_selection) {
} }
TEST_F(LoadBalancerTest, fairness) { TEST_F(LoadBalancerTest, fairness) {
for (size_t round = 0; round < 4; ++round) { for (size_t round = 0; round < 6; ++round) {
brpc::LoadBalancer* lb = NULL; brpc::LoadBalancer* lb = NULL;
SelectArg sa = { NULL, NULL}; SelectArg sa = { NULL, NULL};
if (round == 0) { if (round == 0) {
...@@ -351,6 +359,8 @@ TEST_F(LoadBalancerTest, fairness) { ...@@ -351,6 +359,8 @@ TEST_F(LoadBalancerTest, fairness) {
lb = new brpc::policy::RandomizedLoadBalancer; lb = new brpc::policy::RandomizedLoadBalancer;
} else if (round == 2) { } else if (round == 2) {
lb = new LALB; lb = new LALB;
} else if (3 == round || 4 == round) {
lb = new brpc::policy::WeightedRoundRobinLoadBalancer;
} else { } else {
lb = new brpc::policy::ConsistentHashingLoadBalancer( lb = new brpc::policy::ConsistentHashingLoadBalancer(
brpc::policy::MurmurHash32); brpc::policy::MurmurHash32);
...@@ -375,6 +385,15 @@ TEST_F(LoadBalancerTest, fairness) { ...@@ -375,6 +385,15 @@ TEST_F(LoadBalancerTest, fairness) {
butil::EndPoint dummy; butil::EndPoint dummy;
ASSERT_EQ(0, str2endpoint(addr, &dummy)); ASSERT_EQ(0, str2endpoint(addr, &dummy));
brpc::ServerId id(8888); brpc::ServerId id(8888);
if (3 == round) {
id.tag = "100";
} else if (4 == round) {
if ( i % 50 == 0) {
id.tag = std::to_string(i*2 + butil::fast_rand_less_than(40) + 80);
} else {
id.tag = std::to_string(butil::fast_rand_less_than(40) + 80);
}
}
brpc::SocketOptions options; brpc::SocketOptions options;
options.remote_side = dummy; options.remote_side = dummy;
options.user = new SaveRecycle; options.user = new SaveRecycle;
...@@ -418,18 +437,41 @@ TEST_F(LoadBalancerTest, fairness) { ...@@ -418,18 +437,41 @@ TEST_F(LoadBalancerTest, fairness) {
size_t count_sum = 0; size_t count_sum = 0;
size_t count_squared_sum = 0; size_t count_squared_sum = 0;
std::cout << lb_name << ':' << '\n'; std::cout << lb_name << ':' << '\n';
for (size_t i = 0; i < ids.size(); ++i) {
size_t count = total_count[ids[i].id];
ASSERT_NE(0ul, count) << "i=" << i;
std::cout << i << '=' << count << ' ';
count_sum += count;
count_squared_sum += count * count;
}
std::cout << '\n' if (round != 3 && round !=4) {
<< ": average=" << count_sum/ids.size() for (size_t i = 0; i < ids.size(); ++i) {
<< " deviation=" << sqrt(count_squared_sum * ids.size() - count_sum * count_sum) / ids.size() << std::endl; size_t count = total_count[ids[i].id];
ASSERT_NE(0ul, count) << "i=" << i;
std::cout << i << '=' << count << ' ';
count_sum += count;
count_squared_sum += count * count;
}
std::cout << '\n'
<< ": average=" << count_sum/ids.size()
<< " deviation=" << sqrt(count_squared_sum * ids.size()
- count_sum * count_sum) / ids.size() << std::endl;
} else { // for weighted round robin load balancer
std::cout << "configured weight: " << std::endl;
std::ostringstream os;
brpc::DescribeOptions opt;
lb->Describe(os, opt);
std::cout << os.str() << std::endl;
double scaling_count_sum = 0.0;
double scaling_count_squared_sum = 0.0;
for (size_t i = 0; i < ids.size(); ++i) {
size_t count = total_count[ids[i].id];
ASSERT_NE(0ul, count) << "i=" << i;
std::cout << i << '=' << count << ' ';
double scaling_count = static_cast<double>(count) / std::stoi(ids[i].tag);
scaling_count_sum += scaling_count;
scaling_count_squared_sum += scaling_count * scaling_count;
}
std::cout << '\n'
<< ": scaling average=" << scaling_count_sum/ids.size()
<< " scaling deviation=" << sqrt(scaling_count_squared_sum * ids.size()
- scaling_count_sum * scaling_count_sum) / ids.size() << std::endl;
}
for (size_t i = 0; i < ids.size(); ++i) { for (size_t i = 0; i < ids.size(); ++i) {
ASSERT_EQ(0, brpc::Socket::SetFailed(ids[i].id)); ASSERT_EQ(0, brpc::Socket::SetFailed(ids[i].id));
} }
...@@ -513,4 +555,68 @@ TEST_F(LoadBalancerTest, consistent_hashing) { ...@@ -513,4 +555,68 @@ TEST_F(LoadBalancerTest, consistent_hashing) {
} }
} }
} }
TEST_F(LoadBalancerTest, weighted_round_robin) {
const char* servers[] = {
"10.92.115.19:8831",
"10.42.108.25:8832",
"10.36.150.32:8833",
"10.92.149.48:8834",
"10.42.122.201:8835",
"10.42.122.202:8836"
};
std::string weight[] = {"3", "2", "7", "1ab", "-1", "0"};
std::map<butil::EndPoint, int> configed_weight;
brpc::policy::WeightedRoundRobinLoadBalancer wrrlb;
// Add server to selected list. The server with invalid weight will be skipped.
for (size_t i = 0; i < ARRAY_SIZE(servers); ++i) {
const char *addr = servers[i];
butil::EndPoint dummy;
ASSERT_EQ(0, str2endpoint(addr, &dummy));
brpc::ServerId id(8888);
brpc::SocketOptions options;
options.remote_side = dummy;
options.user = new SaveRecycle;
ASSERT_EQ(0, brpc::Socket::Create(options, &id.id));
id.tag = weight[i];
if ( i < 3 ) {
int weight_num = 0;
ASSERT_TRUE(butil::StringToInt(weight[i], &weight_num));
configed_weight[dummy] = weight_num;
EXPECT_TRUE(wrrlb.AddServer(id));
} else {
EXPECT_FALSE(wrrlb.AddServer(id));
}
}
// Select the best server according to weight configured.
// There are 3 valid servers with weight 3, 2 and 7 respectively.
// We run SelectServer for 12 times. The result number of each server seleted should be
// consistent with weight configured.
std::map<butil::EndPoint, size_t> select_result;
brpc::SocketUniquePtr ptr;
brpc::LoadBalancer::SelectIn in = { 0, false, false, 0u, NULL };
brpc::LoadBalancer::SelectOut out(&ptr);
int total_weight = 12;
std::vector<butil::EndPoint> select_servers;
for (int i = 0; i != total_weight; ++i) {
EXPECT_EQ(0, wrrlb.SelectServer(in, &out));
select_servers.emplace_back(ptr->remote_side());
++select_result[ptr->remote_side()];
}
for (const auto& s : select_servers) {
std::cout << "1=" << s << ", ";
}
std::cout << std::endl;
// Check whether slected result is consistent with expected.
EXPECT_EQ(3, select_result.size());
for (const auto& result : select_result) {
std::cout << result.first << " result=" << result.second
<< " configured=" << configed_weight[result.first] << std::endl;
EXPECT_EQ(result.second, configed_weight[result.first]);
}
}
} //namespace } //namespace
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment