Commit 977d26e0 authored by cdjgit's avatar cdjgit

enhance wrr lb algo

parent c7cbfb37
...@@ -171,44 +171,69 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* ...@@ -171,44 +171,69 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut*
tls.remain_server.id != s->server_list[tls.position].id) { tls.remain_server.id != s->server_list[tls.position].id) {
tls.remain_server.weight = 0; tls.remain_server.weight = 0;
} }
for (uint64_t i = 0; i != tls.stride; ++i) { // The servers that can not be choosed.
SocketId server_id = GetServerInNextStride(s->server_list, tls); std::unordered_set<SocketId> filter;
TLS tls_temp = tls;
uint64_t remain_weight = s->weight_sum;
size_t remain_servers = s->server_list.size();
while (remain_servers > 0) {
SocketId server_id = GetServerInNextStride(s->server_list, filter, tls_temp);
if (!ExcludedServers::IsExcluded(in.excluded, server_id) if (!ExcludedServers::IsExcluded(in.excluded, server_id)
&& Socket::Address(server_id, out->ptr) == 0 && Socket::Address(server_id, out->ptr) == 0
&& !(*out->ptr)->IsLogOff()) { && !(*out->ptr)->IsLogOff()) {
// update tls.
tls.remain_server = tls_temp.remain_server;
tls.position = tls_temp.position;
return 0; return 0;
} else {
// Skip this invalid server. We need calculate a new stride for server selection.
filter.emplace(server_id);
remain_weight -= (s->server_list[s->server_map.at(server_id)]).weight;
--remain_servers;
// Select from begining status.
tls_temp.stride = GetStride(remain_weight, remain_servers);
tls_temp.position = tls.position;
tls_temp.remain_server = tls.remain_server;
continue;
} }
} }
return EHOSTDOWN; return EHOSTDOWN;
} }
SocketId WeightedRoundRobinLoadBalancer::GetServerInNextStride( SocketId WeightedRoundRobinLoadBalancer::GetServerInNextStride(
const std::vector<Server>& server_list, TLS& tls) { const std::vector<Server>& server_list,
SocketId final_server = 0; const std::unordered_set<SocketId>& filter,
TLS& tls) {
SocketId final_server = INVALID_STREAM_ID;
uint64_t stride = tls.stride; uint64_t stride = tls.stride;
if (tls.remain_server.weight > 0) { Server& remain = tls.remain_server;
final_server = tls.remain_server.id; if (remain.weight > 0) {
if (tls.remain_server.weight > stride) { if (filter.count(remain.id) == 0) {
tls.remain_server.weight -= stride; final_server = remain.id;
return final_server; if (remain.weight > stride) {
} else { remain.weight -= stride;
stride -= tls.remain_server.weight; return final_server;
tls.remain_server.weight = 0; } else {
++tls.position; stride -= remain.weight;
tls.position %= server_list.size(); }
} }
remain.weight = 0;
++tls.position;
tls.position %= server_list.size();
} }
while (stride > 0) { while (stride > 0) {
uint32_t configured_weight = server_list[tls.position].weight;
final_server = server_list[tls.position].id; final_server = server_list[tls.position].id;
if (configured_weight > stride) { if (filter.count(final_server) == 0) {
tls.remain_server.id = final_server; uint32_t configured_weight = server_list[tls.position].weight;
tls.remain_server.weight = configured_weight - stride; if (configured_weight > stride) {
return final_server; remain.id = final_server;
remain.weight = configured_weight - stride;
return final_server;
}
stride -= configured_weight;
} }
stride -= configured_weight;
++tls.position; ++tls.position;
tls.position %= server_list.size(); tls.position %= server_list.size();
} }
return final_server; return final_server;
} }
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <map> #include <map>
#include <vector> #include <vector>
#include <unordered_set>
#include "butil/containers/doubly_buffered_data.h" #include "butil/containers/doubly_buffered_data.h"
#include "brpc/load_balancer.h" #include "brpc/load_balancer.h"
...@@ -75,7 +76,8 @@ private: ...@@ -75,7 +76,8 @@ private:
static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers); static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers);
static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers); static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers);
static SocketId GetServerInNextStride(const std::vector<Server>& server_list, static SocketId GetServerInNextStride(const std::vector<Server>& server_list,
TLS& tls); const std::unordered_set<SocketId>& filter,
TLS& tls);
butil::DoublyBufferedData<Servers, TLS> _db_servers; butil::DoublyBufferedData<Servers, TLS> _db_servers;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment