Commit d894cba7 authored by cdjin's avatar cdjin

wrr algorithm enhancement

parent 74df9729
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
// Authors: Daojin Cai (caidaojin@qiyi.com) // Authors: Daojin Cai (caidaojin@qiyi.com)
#include "butil/fast_rand.h"
#include "brpc/socket.h" #include "brpc/socket.h"
#include "brpc/policy/weighted_round_robin_load_balancer.h" #include "brpc/policy/weighted_round_robin_load_balancer.h"
#include "butil/strings/string_number_conversions.h" #include "butil/strings/string_number_conversions.h"
...@@ -21,7 +22,24 @@ ...@@ -21,7 +22,24 @@
namespace brpc { namespace brpc {
namespace policy { namespace policy {
static const int EraseBatchSize = 100; bool IsCoprime(uint32_t num1, uint32_t num2) {
uint32_t temp;
if (num1 < num2) {
temp = num1;
num1 = num2;
num2 = temp;
}
while (true) {
temp = num1 % num2;
if (temp == 0) {
break;
} else {
num1 = num2;
num2 = temp;
}
}
return num2 == 1;
}
bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) { bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) {
if (bg.server_list.capacity() < 128) { if (bg.server_list.capacity() < 128) {
...@@ -33,6 +51,7 @@ bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) { ...@@ -33,6 +51,7 @@ bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) {
bg.server_map.emplace(id.id, bg.server_list.size()).second; bg.server_map.emplace(id.id, bg.server_list.size()).second;
if (insert_server) { if (insert_server) {
bg.server_list.emplace_back(id.id, weight); bg.server_list.emplace_back(id.id, weight);
bg.weight_sum += weight;
return true; return true;
} }
} else { } else {
...@@ -45,6 +64,7 @@ bool WeightedRoundRobinLoadBalancer::Remove(Servers& bg, const ServerId& id) { ...@@ -45,6 +64,7 @@ bool WeightedRoundRobinLoadBalancer::Remove(Servers& bg, const ServerId& id) {
auto iter = bg.server_map.find(id.id); auto iter = bg.server_map.find(id.id);
if (iter != bg.server_map.end()) { if (iter != bg.server_map.end()) {
const size_t index = iter->second; const size_t index = iter->second;
bg.weight_sum -= bg.server_list[index].second;
bg.server_list[index] = bg.server_list.back(); bg.server_list[index] = bg.server_list.back();
bg.server_map[bg.server_list[index].first] = index; bg.server_map[bg.server_list[index].first] = index;
bg.server_list.pop_back(); bg.server_list.pop_back();
...@@ -107,40 +127,19 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* ...@@ -107,40 +127,19 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut*
return ENODATA; return ENODATA;
} }
TLS& tls = s.tls(); TLS& tls = s.tls();
int64_t best = -1; if (tls.IsNeededCaculateNewStride(s->weight_sum, s->server_list.size())) {
int total_weight = 0; tls.stride = GetStride(s->weight_sum, s->server_list.size());
// TODO: each thread requsts service as the same sequence. tls.offset = butil::fast_rand_less_than(tls.stride);
// We can set a random beginning position for each thread. }
for (const auto& server : s->server_list) { // If server list changed, the position may be out of range.
// A new server is added or the wrr fisrt run. tls.position %= s->server_list.size();
// Add the servers into TLS. // Check whether remain server was removed from server list.
const SocketId server_id = server.first; if (tls.HasRemainServer() &&
auto iter = tls.emplace(server_id, 0).first; s->server_map.find(tls.remain_server.first) == s->server_map.end()) {
if (ExcludedServers::IsExcluded(in.excluded, server_id) tls.ResetRemainServer();
|| Socket::Address(server_id, out->ptr) != 0 }
|| (*out->ptr)->IsLogOff()) { for ( uint32_t i = 0; i != tls.stride; ++i) {
continue; int64_t best = GetBestServer(s->server_list, tls, tls.stride);
}
iter->second += server.second;
total_weight += server.second;
if (best == -1 || tls[server_id] > tls[best]) {
best = server_id;
}
}
// If too many servers were removed from _db_servers(name service),
// remove these servers from TLS.
if (s->server_list.size() + EraseBatchSize < tls.size()) {
auto iter = tls.begin();
while (iter != tls.end()) {
if (s->server_map.find(iter->first) == s->server_map.end()) {
iter = tls.erase(iter);
} else {
++iter;
}
}
}
if (best != -1) {
tls[best] -= total_weight;
if (!ExcludedServers::IsExcluded(in.excluded, best) if (!ExcludedServers::IsExcluded(in.excluded, best)
&& Socket::Address(best, out->ptr) == 0 && Socket::Address(best, out->ptr) == 0
&& !(*out->ptr)->IsLogOff()) { && !(*out->ptr)->IsLogOff()) {
...@@ -150,6 +149,80 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* ...@@ -150,6 +149,80 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut*
return EHOSTDOWN; return EHOSTDOWN;
} }
int64_t WeightedRoundRobinLoadBalancer::GetBestServer(
const std::vector<std::pair<SocketId, int>>& server_list,
TLS& tls, uint32_t stride) {
uint32_t comp_weight = 0;
int64_t final_server = -1;
while (stride > 0) {
if (tls.HasRemainServer()) {
uint32_t remain_weight = tls.remain_server.second;
if (remain_weight < stride) {
TryToGetFinalServer(tls, tls.remain_server,
comp_weight, &final_server);
tls.ResetRemainServer();
stride -= remain_weight;
} else if (remain_weight == stride) {
TryToGetFinalServer(tls, tls.remain_server,
comp_weight, &final_server);
tls.ResetRemainServer();
break;
} else {
TryToGetFinalServer(tls,
std::pair<SocketId, int>(tls.remain_server.first, stride),
comp_weight, &final_server);
tls.remain_server.second -= stride;
break;
}
} else {
uint32_t weight = server_list[tls.position].second;
if (weight < stride) {
TryToGetFinalServer(tls, server_list[tls.position],
comp_weight, &final_server);
stride -= weight;
tls.UpdatePosition(server_list.size());
} else if (weight == stride) {
TryToGetFinalServer(tls, server_list[tls.position],
comp_weight, &final_server);
tls.UpdatePosition(server_list.size());
break;
} else {
TryToGetFinalServer(tls,
std::pair<SocketId, int>(
server_list[tls.position].first, stride),
comp_weight, &final_server);
tls.SetRemainServer(server_list[tls.position].first,
weight - stride);
tls.UpdatePosition(server_list.size());
break;
}
}
}
return final_server;
}
uint32_t WeightedRoundRobinLoadBalancer::GetStride(
const uint32_t weight_sum, const uint32_t num) {
uint32_t average_weight = weight_sum / num;
// The stride is the first number which is greater than or equal to
// average weight and coprime to weight_sum.
while (!IsCoprime(weight_sum, average_weight)) {
++average_weight;
}
return average_weight;
}
void WeightedRoundRobinLoadBalancer::TryToGetFinalServer(
const TLS& tls, const std::pair<SocketId, int> server,
uint32_t& comp_weight, int64_t* final_server) {
if (*final_server == -1) {
comp_weight += server.second;
if (comp_weight >= tls.offset) {
*final_server = server.first;
}
}
}
LoadBalancer* WeightedRoundRobinLoadBalancer::New() const { LoadBalancer* WeightedRoundRobinLoadBalancer::New() const {
return new (std::nothrow) WeightedRoundRobinLoadBalancer; return new (std::nothrow) WeightedRoundRobinLoadBalancer;
} }
......
...@@ -44,14 +44,55 @@ private: ...@@ -44,14 +44,55 @@ private:
std::vector<std::pair<SocketId, int>> server_list; std::vector<std::pair<SocketId, int>> server_list;
// The value is the index of the server in "server_list". // The value is the index of the server in "server_list".
std::map<SocketId, size_t> server_map; std::map<SocketId, size_t> server_map;
uint32_t weight_sum = 0;
};
struct TLS {
TLS(): remain_server(0, 0) { }
uint32_t position = 0;
uint32_t stride = 0;
uint32_t offset = 0;
std::pair<SocketId, int> remain_server;
bool HasRemainServer() const {
return remain_server.second != 0;
}
void SetRemainServer(const SocketId id, const int weight) {
remain_server.first = id;
remain_server.second = weight;
}
void ResetRemainServer() {
remain_server.second = 0;
}
void UpdatePosition(const uint32_t size) {
++position;
position %= size;
}
// If server list changed, we need caculate a new stride.
bool IsNeededCaculateNewStride(const uint32_t curr_weight_sum,
const uint32_t curr_servers_num) {
if (curr_weight_sum != weight_sum
|| curr_servers_num != servers_num) {
weight_sum = curr_weight_sum;
servers_num = curr_servers_num;
return true;
}
return false;
}
private:
uint32_t weight_sum = 0;
uint32_t servers_num = 0;
}; };
// The value is current weight for a server.
// It will be changed in the selection of servers.
using TLS = std::map<SocketId, int>;
static bool Add(Servers& bg, const ServerId& id); static bool Add(Servers& bg, const ServerId& id);
static bool Remove(Servers& bg, const ServerId& id); static bool Remove(Servers& bg, const ServerId& id);
static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers); static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers);
static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers); static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers);
static int64_t GetBestServer(
const std::vector<std::pair<SocketId, int>>& server_list,
TLS& tls, uint32_t stride);
// Get a reasonable stride according to weights configured of servers.
static uint32_t GetStride(const uint32_t weight_sum, const uint32_t num);
static void TryToGetFinalServer(const TLS& tls,
const std::pair<SocketId, int> server,
uint32_t& comp_weight, int64_t* final_server);
butil::DoublyBufferedData<Servers, TLS> _db_servers; butil::DoublyBufferedData<Servers, TLS> _db_servers;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment