Commit f375e9b6 authored by root's avatar root

fix for code review comments

parent 54df0863
...@@ -21,8 +21,7 @@ ...@@ -21,8 +21,7 @@
#include "brpc/policy/weighted_round_robin_load_balancer.h" #include "brpc/policy/weighted_round_robin_load_balancer.h"
#include "butil/strings/string_number_conversions.h" #include "butil/strings/string_number_conversions.h"
namespace brpc { namespace {
namespace policy {
const std::vector<uint32_t> prime_stride = { const std::vector<uint32_t> prime_stride = {
2,3,5,11,17,29,47,71,107,137,163,251,307,379,569,683, 2,3,5,11,17,29,47,71,107,137,163,251,307,379,569,683,
...@@ -48,6 +47,24 @@ bool IsCoprime(uint32_t num1, uint32_t num2) { ...@@ -48,6 +47,24 @@ bool IsCoprime(uint32_t num1, uint32_t num2) {
return num2 == 1; return num2 == 1;
} }
// Get a reasonable stride according to weights configured of servers.
uint32_t GetStride(const uint32_t weight_sum, const uint32_t num) {
uint32_t average_weight = weight_sum / num;
auto iter = std::lower_bound(prime_stride.begin(), prime_stride.end(),
average_weight);
while (iter != prime_stride.end()
&& !IsCoprime(weight_sum, *iter)) {
++iter;
}
CHECK(iter != prime_stride.end()) << "Failed to get stride";
return *iter > weight_sum ? *iter % weight_sum : *iter;
}
} // namespace
namespace brpc {
namespace policy {
bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) { bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) {
if (bg.server_list.capacity() < 128) { if (bg.server_list.capacity() < 128) {
bg.server_list.reserve(128); bg.server_list.reserve(128);
...@@ -143,12 +160,12 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* ...@@ -143,12 +160,12 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut*
// If server list changed, the position may be out of range. // If server list changed, the position may be out of range.
tls.position %= s->server_list.size(); tls.position %= s->server_list.size();
// Check whether remain server was removed from server list. // Check whether remain server was removed from server list.
if (tls.HasRemainServer() && if (tls.remain_server.weight > 0 &&
tls.remain_server.id != s->server_list[tls.position].id) { tls.remain_server.id != s->server_list[tls.position].id) {
tls.ResetRemainServer(); tls.remain_server.weight = 0;
} }
for (uint32_t i = 0; i != tls.stride; ++i) { for (uint32_t i = 0; i != tls.stride; ++i) {
int64_t server_id = GetServerInNextStride(s->server_list, tls); SocketId server_id = GetServerInNextStride(s->server_list, tls);
if (!ExcludedServers::IsExcluded(in.excluded, server_id) if (!ExcludedServers::IsExcluded(in.excluded, server_id)
&& Socket::Address(server_id, out->ptr) == 0 && Socket::Address(server_id, out->ptr) == 0
&& !(*out->ptr)->IsLogOff()) { && !(*out->ptr)->IsLogOff()) {
...@@ -160,45 +177,35 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* ...@@ -160,45 +177,35 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut*
int64_t WeightedRoundRobinLoadBalancer::GetServerInNextStride( int64_t WeightedRoundRobinLoadBalancer::GetServerInNextStride(
const std::vector<Server>& server_list, TLS& tls) { const std::vector<Server>& server_list, TLS& tls) {
int64_t final_server = -1; SocketId final_server = 0;
int stride = tls.stride; int stride = tls.stride;
if (tls.HasRemainServer()) { if (tls.remain_server.weight > 0) {
final_server = tls.remain_server.id; final_server = tls.remain_server.id;
if (tls.remain_server.weight > stride) { if (tls.remain_server.weight > stride) {
tls.remain_server.weight -= stride; tls.remain_server.weight -= stride;
return final_server; return final_server;
} else { } else {
stride -= tls.remain_server.weight; stride -= tls.remain_server.weight;
tls.ResetRemainServer(); tls.remain_server.weight = 0;
tls.UpdatePosition(server_list.size()); ++tls.position;
tls.position %= server_list.size();
} }
} }
while (stride > 0) { while (stride > 0) {
int configured_weight = server_list[tls.position].weight;
final_server = server_list[tls.position].id; final_server = server_list[tls.position].id;
if (server_list[tls.position].weight > stride) { if (configured_weight > stride) {
tls.SetRemainServer(server_list[tls.position].id, tls.remain_server.id = final_server;
server_list[tls.position].weight - stride); tls.remain_server.weight = configured_weight - stride;
return final_server; return final_server;
} }
stride -= server_list[tls.position].weight; stride -= configured_weight;
tls.UpdatePosition(server_list.size()); ++tls.position;
tls.position %= server_list.size();
} }
return final_server; return final_server;
} }
uint32_t WeightedRoundRobinLoadBalancer::GetStride(
const uint32_t weight_sum, const uint32_t num) {
uint32_t average_weight = weight_sum / num;
auto iter = std::lower_bound(prime_stride.begin(), prime_stride.end(),
average_weight);
while (iter != prime_stride.end()
&& !IsCoprime(weight_sum, *iter)) {
++iter;
}
CHECK(iter != prime_stride.end()) << "Failed to get stride";
return *iter > weight_sum ? *iter % weight_sum : *iter;
}
LoadBalancer* WeightedRoundRobinLoadBalancer::New() const { LoadBalancer* WeightedRoundRobinLoadBalancer::New() const {
return new (std::nothrow) WeightedRoundRobinLoadBalancer; return new (std::nothrow) WeightedRoundRobinLoadBalancer;
} }
......
...@@ -55,20 +55,6 @@ private: ...@@ -55,20 +55,6 @@ private:
uint32_t position = 0; uint32_t position = 0;
uint32_t stride = 0; uint32_t stride = 0;
Server remain_server; Server remain_server;
bool HasRemainServer() const {
return remain_server.weight != 0;
}
void SetRemainServer(const SocketId id, const int weight) {
remain_server.id = id;
remain_server.weight = weight;
}
void ResetRemainServer() {
remain_server.weight = 0;
}
void UpdatePosition(const uint32_t size) {
++position;
position %= size;
}
// If server list changed, we need caculate a new stride. // If server list changed, we need caculate a new stride.
bool IsNeededCaculateNewStride(const uint32_t curr_weight_sum, bool IsNeededCaculateNewStride(const uint32_t curr_weight_sum,
const uint32_t curr_servers_num) { const uint32_t curr_servers_num) {
...@@ -90,8 +76,6 @@ private: ...@@ -90,8 +76,6 @@ private:
static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers); static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers);
static int64_t GetServerInNextStride(const std::vector<Server>& server_list, static int64_t GetServerInNextStride(const std::vector<Server>& server_list,
TLS& tls); TLS& tls);
// Get a reasonable stride according to weights configured of servers.
static uint32_t GetStride(const uint32_t weight_sum, const uint32_t num);
butil::DoublyBufferedData<Servers, TLS> _db_servers; butil::DoublyBufferedData<Servers, TLS> _db_servers;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment