Commit 977d26e0 authored by cdjgit's avatar cdjgit

enhance wrr lb algo

parent c7cbfb37
......@@ -171,44 +171,69 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut*
tls.remain_server.id != s->server_list[tls.position].id) {
tls.remain_server.weight = 0;
}
for (uint64_t i = 0; i != tls.stride; ++i) {
SocketId server_id = GetServerInNextStride(s->server_list, tls);
// The servers that can not be choosed.
std::unordered_set<SocketId> filter;
TLS tls_temp = tls;
uint64_t remain_weight = s->weight_sum;
size_t remain_servers = s->server_list.size();
while (remain_servers > 0) {
SocketId server_id = GetServerInNextStride(s->server_list, filter, tls_temp);
if (!ExcludedServers::IsExcluded(in.excluded, server_id)
&& Socket::Address(server_id, out->ptr) == 0
&& !(*out->ptr)->IsLogOff()) {
// update tls.
tls.remain_server = tls_temp.remain_server;
tls.position = tls_temp.position;
return 0;
} else {
// Skip this invalid server. We need calculate a new stride for server selection.
filter.emplace(server_id);
remain_weight -= (s->server_list[s->server_map.at(server_id)]).weight;
--remain_servers;
// Select from begining status.
tls_temp.stride = GetStride(remain_weight, remain_servers);
tls_temp.position = tls.position;
tls_temp.remain_server = tls.remain_server;
continue;
}
}
return EHOSTDOWN;
}
SocketId WeightedRoundRobinLoadBalancer::GetServerInNextStride(
const std::vector<Server>& server_list, TLS& tls) {
SocketId final_server = 0;
const std::vector<Server>& server_list,
const std::unordered_set<SocketId>& filter,
TLS& tls) {
SocketId final_server = INVALID_STREAM_ID;
uint64_t stride = tls.stride;
if (tls.remain_server.weight > 0) {
final_server = tls.remain_server.id;
if (tls.remain_server.weight > stride) {
tls.remain_server.weight -= stride;
return final_server;
} else {
stride -= tls.remain_server.weight;
tls.remain_server.weight = 0;
++tls.position;
tls.position %= server_list.size();
Server& remain = tls.remain_server;
if (remain.weight > 0) {
if (filter.count(remain.id) == 0) {
final_server = remain.id;
if (remain.weight > stride) {
remain.weight -= stride;
return final_server;
} else {
stride -= remain.weight;
}
}
remain.weight = 0;
++tls.position;
tls.position %= server_list.size();
}
while (stride > 0) {
uint32_t configured_weight = server_list[tls.position].weight;
final_server = server_list[tls.position].id;
if (configured_weight > stride) {
tls.remain_server.id = final_server;
tls.remain_server.weight = configured_weight - stride;
return final_server;
if (filter.count(final_server) == 0) {
uint32_t configured_weight = server_list[tls.position].weight;
if (configured_weight > stride) {
remain.id = final_server;
remain.weight = configured_weight - stride;
return final_server;
}
stride -= configured_weight;
}
stride -= configured_weight;
++tls.position;
tls.position %= server_list.size();
tls.position %= server_list.size();
}
return final_server;
}
......
......@@ -19,6 +19,7 @@
#include <map>
#include <vector>
#include <unordered_set>
#include "butil/containers/doubly_buffered_data.h"
#include "brpc/load_balancer.h"
......@@ -75,7 +76,8 @@ private:
static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers);
static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers);
static SocketId GetServerInNextStride(const std::vector<Server>& server_list,
TLS& tls);
const std::unordered_set<SocketId>& filter,
TLS& tls);
butil::DoublyBufferedData<Servers, TLS> _db_servers;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment