Commit 54df0863 authored by root's avatar root

fix for code review comments

parent d4432f32
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
// Authors: Daojin Cai (caidaojin@qiyi.com) // Authors: Daojin Cai (caidaojin@qiyi.com)
#include <algorithm>
#include "butil/fast_rand.h" #include "butil/fast_rand.h"
#include "brpc/socket.h" #include "brpc/socket.h"
#include "brpc/policy/weighted_round_robin_load_balancer.h" #include "brpc/policy/weighted_round_robin_load_balancer.h"
...@@ -22,6 +24,11 @@ ...@@ -22,6 +24,11 @@
namespace brpc { namespace brpc {
namespace policy { namespace policy {
const std::vector<uint32_t> prime_stride = {
2,3,5,11,17,29,47,71,107,137,163,251,307,379,569,683,
857,1289,1543,1949,2617,2927,3407,4391,6599,9901,14867,
22303,33457,50207,75323,112997,169501,254257,381389,572087};
bool IsCoprime(uint32_t num1, uint32_t num2) { bool IsCoprime(uint32_t num1, uint32_t num2) {
uint32_t temp; uint32_t temp;
if (num1 < num2) { if (num1 < num2) {
...@@ -64,9 +71,9 @@ bool WeightedRoundRobinLoadBalancer::Remove(Servers& bg, const ServerId& id) { ...@@ -64,9 +71,9 @@ bool WeightedRoundRobinLoadBalancer::Remove(Servers& bg, const ServerId& id) {
auto iter = bg.server_map.find(id.id); auto iter = bg.server_map.find(id.id);
if (iter != bg.server_map.end()) { if (iter != bg.server_map.end()) {
const size_t index = iter->second; const size_t index = iter->second;
bg.weight_sum -= bg.server_list[index].second; bg.weight_sum -= bg.server_list[index].weight;
bg.server_list[index] = bg.server_list.back(); bg.server_list[index] = bg.server_list.back();
bg.server_map[bg.server_list[index].first] = index; bg.server_map[bg.server_list[index].id] = index;
bg.server_list.pop_back(); bg.server_list.pop_back();
bg.server_map.erase(iter); bg.server_map.erase(iter);
return true; return true;
...@@ -137,13 +144,13 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* ...@@ -137,13 +144,13 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut*
tls.position %= s->server_list.size(); tls.position %= s->server_list.size();
// Check whether remain server was removed from server list. // Check whether remain server was removed from server list.
if (tls.HasRemainServer() && if (tls.HasRemainServer() &&
s->server_map.find(tls.remain_server.first) == s->server_map.end()) { tls.remain_server.id != s->server_list[tls.position].id) {
tls.ResetRemainServer(); tls.ResetRemainServer();
} }
for ( uint32_t i = 0; i != tls.stride; ++i) { for (uint32_t i = 0; i != tls.stride; ++i) {
int64_t best = GetBestServer(s->server_list, tls); int64_t server_id = GetServerInNextStride(s->server_list, tls);
if (!ExcludedServers::IsExcluded(in.excluded, best) if (!ExcludedServers::IsExcluded(in.excluded, server_id)
&& Socket::Address(best, out->ptr) == 0 && Socket::Address(server_id, out->ptr) == 0
&& !(*out->ptr)->IsLogOff()) { && !(*out->ptr)->IsLogOff()) {
return 0; return 0;
} }
...@@ -151,35 +158,30 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* ...@@ -151,35 +158,30 @@ int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut*
return EHOSTDOWN; return EHOSTDOWN;
} }
int64_t WeightedRoundRobinLoadBalancer::GetBestServer( int64_t WeightedRoundRobinLoadBalancer::GetServerInNextStride(
const std::vector<std::pair<SocketId, int>>& server_list, TLS& tls) { const std::vector<Server>& server_list, TLS& tls) {
int64_t final_server = -1; int64_t final_server = -1;
int stride = tls.stride; int stride = tls.stride;
int weight = 0;
while (stride > 0) {
if (tls.HasRemainServer()) { if (tls.HasRemainServer()) {
weight = tls.remain_server.second; final_server = tls.remain_server.id;
if (weight <= stride) { if (tls.remain_server.weight > stride) {
tls.ResetRemainServer(); tls.remain_server.weight -= stride;
} else { return final_server;
tls.remain_server.second -= stride;
}
} else { } else {
weight = server_list[tls.position].second; stride -= tls.remain_server.weight;
if (weight > stride) { tls.ResetRemainServer();
tls.SetRemainServer(server_list[tls.position].first,
weight - stride);
}
tls.UpdatePosition(server_list.size()); tls.UpdatePosition(server_list.size());
} }
stride -= weight;
} }
if (tls.HasRemainServer()) { while (stride > 0) {
final_server = tls.remain_server.first; final_server = server_list[tls.position].id;
} else { if (server_list[tls.position].weight > stride) {
size_t index = tls.position == 0 ? server_list.size() - 1 tls.SetRemainServer(server_list[tls.position].id,
: tls.position - 1; server_list[tls.position].weight - stride);
final_server = server_list[index].first; return final_server;
}
stride -= server_list[tls.position].weight;
tls.UpdatePosition(server_list.size());
} }
return final_server; return final_server;
} }
...@@ -187,12 +189,14 @@ int64_t WeightedRoundRobinLoadBalancer::GetBestServer( ...@@ -187,12 +189,14 @@ int64_t WeightedRoundRobinLoadBalancer::GetBestServer(
uint32_t WeightedRoundRobinLoadBalancer::GetStride( uint32_t WeightedRoundRobinLoadBalancer::GetStride(
const uint32_t weight_sum, const uint32_t num) { const uint32_t weight_sum, const uint32_t num) {
uint32_t average_weight = weight_sum / num; uint32_t average_weight = weight_sum / num;
// The stride is the first number which is greater than or equal to auto iter = std::lower_bound(prime_stride.begin(), prime_stride.end(),
// average weight and coprime to weight_sum. average_weight);
while (!IsCoprime(weight_sum, average_weight)) { while (iter != prime_stride.end()
++average_weight; && !IsCoprime(weight_sum, *iter)) {
} ++iter;
return average_weight; }
CHECK(iter != prime_stride.end()) << "Failed to get stride";
return *iter > weight_sum ? *iter % weight_sum : *iter;
} }
LoadBalancer* WeightedRoundRobinLoadBalancer::New() const { LoadBalancer* WeightedRoundRobinLoadBalancer::New() const {
...@@ -216,7 +220,7 @@ void WeightedRoundRobinLoadBalancer::Describe( ...@@ -216,7 +220,7 @@ void WeightedRoundRobinLoadBalancer::Describe(
} else { } else {
os << "n=" << s->server_list.size() << ':'; os << "n=" << s->server_list.size() << ':';
for (const auto& server : s->server_list) { for (const auto& server : s->server_list) {
os << ' ' << server.first << '(' << server.second << ')'; os << ' ' << server.id << '(' << server.weight << ')';
} }
} }
os << '}'; os << '}';
......
...@@ -39,27 +39,31 @@ public: ...@@ -39,27 +39,31 @@ public:
void Describe(std::ostream&, const DescribeOptions& options); void Describe(std::ostream&, const DescribeOptions& options);
private: private:
struct Server {
Server(SocketId s_id = 0, int s_w = 0): id(s_id), weight(s_w) {}
SocketId id;
int weight;
};
struct Servers { struct Servers {
// The value is configured weight for each server. // The value is configured weight for each server.
std::vector<std::pair<SocketId, int>> server_list; std::vector<Server> server_list;
// The value is the index of the server in "server_list". // The value is the index of the server in "server_list".
std::map<SocketId, size_t> server_map; std::map<SocketId, size_t> server_map;
uint32_t weight_sum = 0; uint32_t weight_sum = 0;
}; };
struct TLS { struct TLS {
TLS(): remain_server(0, 0) { }
uint32_t position = 0; uint32_t position = 0;
uint32_t stride = 0; uint32_t stride = 0;
std::pair<SocketId, int> remain_server; Server remain_server;
bool HasRemainServer() const { bool HasRemainServer() const {
return remain_server.second != 0; return remain_server.weight != 0;
} }
void SetRemainServer(const SocketId id, const int weight) { void SetRemainServer(const SocketId id, const int weight) {
remain_server.first = id; remain_server.id = id;
remain_server.second = weight; remain_server.weight = weight;
} }
void ResetRemainServer() { void ResetRemainServer() {
remain_server.second = 0; remain_server.weight = 0;
} }
void UpdatePosition(const uint32_t size) { void UpdatePosition(const uint32_t size) {
++position; ++position;
...@@ -84,8 +88,7 @@ private: ...@@ -84,8 +88,7 @@ private:
static bool Remove(Servers& bg, const ServerId& id); static bool Remove(Servers& bg, const ServerId& id);
static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers); static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers);
static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers); static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers);
static int64_t GetBestServer( static int64_t GetServerInNextStride(const std::vector<Server>& server_list,
const std::vector<std::pair<SocketId, int>>& server_list,
TLS& tls); TLS& tls);
// Get a reasonable stride according to weights configured of servers. // Get a reasonable stride according to weights configured of servers.
static uint32_t GetStride(const uint32_t weight_sum, const uint32_t num); static uint32_t GetStride(const uint32_t weight_sum, const uint32_t num);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment