Commit 54df0863 authored by root's avatar root

fix for code review comments

parent d4432f32
// Copyright (c) 2018 Iqiyi, Inc. // Copyright (c) 2018 Iqiyi, Inc.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Authors: Daojin Cai (caidaojin@qiyi.com) // Authors: Daojin Cai (caidaojin@qiyi.com)
#include "butil/fast_rand.h" #include <algorithm>
#include "brpc/socket.h"
#include "brpc/policy/weighted_round_robin_load_balancer.h" #include "butil/fast_rand.h"
#include "butil/strings/string_number_conversions.h" #include "brpc/socket.h"
#include "brpc/policy/weighted_round_robin_load_balancer.h"
namespace brpc { #include "butil/strings/string_number_conversions.h"
namespace policy {
namespace brpc {
bool IsCoprime(uint32_t num1, uint32_t num2) { namespace policy {
uint32_t temp;
if (num1 < num2) { const std::vector<uint32_t> prime_stride = {
temp = num1; 2,3,5,11,17,29,47,71,107,137,163,251,307,379,569,683,
num1 = num2; 857,1289,1543,1949,2617,2927,3407,4391,6599,9901,14867,
num2 = temp; 22303,33457,50207,75323,112997,169501,254257,381389,572087};
}
while (true) { bool IsCoprime(uint32_t num1, uint32_t num2) {
temp = num1 % num2; uint32_t temp;
if (temp == 0) { if (num1 < num2) {
break; temp = num1;
} else { num1 = num2;
num1 = num2; num2 = temp;
num2 = temp; }
} while (true) {
} temp = num1 % num2;
return num2 == 1; if (temp == 0) {
} break;
} else {
bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) { num1 = num2;
if (bg.server_list.capacity() < 128) { num2 = temp;
bg.server_list.reserve(128); }
} }
int weight = 0; return num2 == 1;
if (butil::StringToInt(id.tag, &weight) && weight > 0) { }
bool insert_server =
bg.server_map.emplace(id.id, bg.server_list.size()).second; bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) {
if (insert_server) { if (bg.server_list.capacity() < 128) {
bg.server_list.emplace_back(id.id, weight); bg.server_list.reserve(128);
bg.weight_sum += weight; }
return true; int weight = 0;
} if (butil::StringToInt(id.tag, &weight) && weight > 0) {
} else { bool insert_server =
LOG(ERROR) << "Invalid weight is set: " << id.tag; bg.server_map.emplace(id.id, bg.server_list.size()).second;
} if (insert_server) {
return false; bg.server_list.emplace_back(id.id, weight);
} bg.weight_sum += weight;
return true;
bool WeightedRoundRobinLoadBalancer::Remove(Servers& bg, const ServerId& id) { }
auto iter = bg.server_map.find(id.id); } else {
if (iter != bg.server_map.end()) { LOG(ERROR) << "Invalid weight is set: " << id.tag;
const size_t index = iter->second; }
bg.weight_sum -= bg.server_list[index].second; return false;
bg.server_list[index] = bg.server_list.back(); }
bg.server_map[bg.server_list[index].first] = index;
bg.server_list.pop_back(); bool WeightedRoundRobinLoadBalancer::Remove(Servers& bg, const ServerId& id) {
bg.server_map.erase(iter); auto iter = bg.server_map.find(id.id);
return true; if (iter != bg.server_map.end()) {
} const size_t index = iter->second;
return false; bg.weight_sum -= bg.server_list[index].weight;
} bg.server_list[index] = bg.server_list.back();
bg.server_map[bg.server_list[index].id] = index;
size_t WeightedRoundRobinLoadBalancer::BatchAdd( bg.server_list.pop_back();
Servers& bg, const std::vector<ServerId>& servers) { bg.server_map.erase(iter);
size_t count = 0; return true;
for (size_t i = 0; i < servers.size(); ++i) { }
count += !!Add(bg, servers[i]); return false;
} }
return count;
} size_t WeightedRoundRobinLoadBalancer::BatchAdd(
Servers& bg, const std::vector<ServerId>& servers) {
size_t WeightedRoundRobinLoadBalancer::BatchRemove( size_t count = 0;
Servers& bg, const std::vector<ServerId>& servers) { for (size_t i = 0; i < servers.size(); ++i) {
size_t count = 0; count += !!Add(bg, servers[i]);
for (size_t i = 0; i < servers.size(); ++i) { }
count += !!Remove(bg, servers[i]); return count;
} }
return count;
} size_t WeightedRoundRobinLoadBalancer::BatchRemove(
Servers& bg, const std::vector<ServerId>& servers) {
bool WeightedRoundRobinLoadBalancer::AddServer(const ServerId& id) { size_t count = 0;
return _db_servers.Modify(Add, id); for (size_t i = 0; i < servers.size(); ++i) {
} count += !!Remove(bg, servers[i]);
}
bool WeightedRoundRobinLoadBalancer::RemoveServer(const ServerId& id) { return count;
return _db_servers.Modify(Remove, id); }
}
bool WeightedRoundRobinLoadBalancer::AddServer(const ServerId& id) {
size_t WeightedRoundRobinLoadBalancer::AddServersInBatch( return _db_servers.Modify(Add, id);
const std::vector<ServerId>& servers) { }
const size_t n = _db_servers.Modify(BatchAdd, servers);
LOG_IF(ERROR, n != servers.size()) bool WeightedRoundRobinLoadBalancer::RemoveServer(const ServerId& id) {
<< "Fail to AddServersInBatch, expected " << servers.size() return _db_servers.Modify(Remove, id);
<< " actually " << n; }
return n;
} size_t WeightedRoundRobinLoadBalancer::AddServersInBatch(
const std::vector<ServerId>& servers) {
size_t WeightedRoundRobinLoadBalancer::RemoveServersInBatch( const size_t n = _db_servers.Modify(BatchAdd, servers);
const std::vector<ServerId>& servers) { LOG_IF(ERROR, n != servers.size())
const size_t n = _db_servers.Modify(BatchRemove, servers); << "Fail to AddServersInBatch, expected " << servers.size()
LOG_IF(ERROR, n != servers.size()) << " actually " << n;
<< "Fail to RemoveServersInBatch, expected " << servers.size() return n;
<< " actually " << n; }
return n;
} size_t WeightedRoundRobinLoadBalancer::RemoveServersInBatch(
const std::vector<ServerId>& servers) {
int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* out) { const size_t n = _db_servers.Modify(BatchRemove, servers);
butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s; LOG_IF(ERROR, n != servers.size())
if (_db_servers.Read(&s) != 0) { << "Fail to RemoveServersInBatch, expected " << servers.size()
return ENOMEM; << " actually " << n;
} return n;
if (s->server_list.empty()) { }
return ENODATA;
} int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* out) {
TLS& tls = s.tls(); butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s;
if (tls.IsNeededCaculateNewStride(s->weight_sum, s->server_list.size())) { if (_db_servers.Read(&s) != 0) {
if (tls.stride == 0) { return ENOMEM;
tls.position = butil::fast_rand_less_than(s->server_list.size()); }
} if (s->server_list.empty()) {
tls.stride = GetStride(s->weight_sum, s->server_list.size()); return ENODATA;
} }
// If server list changed, the position may be out of range. TLS& tls = s.tls();
tls.position %= s->server_list.size(); if (tls.IsNeededCaculateNewStride(s->weight_sum, s->server_list.size())) {
// Check whether remain server was removed from server list. if (tls.stride == 0) {
if (tls.HasRemainServer() && tls.position = butil::fast_rand_less_than(s->server_list.size());
s->server_map.find(tls.remain_server.first) == s->server_map.end()) { }
tls.ResetRemainServer(); tls.stride = GetStride(s->weight_sum, s->server_list.size());
} }
for ( uint32_t i = 0; i != tls.stride; ++i) { // If server list changed, the position may be out of range.
int64_t best = GetBestServer(s->server_list, tls); tls.position %= s->server_list.size();
if (!ExcludedServers::IsExcluded(in.excluded, best) // Check whether remain server was removed from server list.
&& Socket::Address(best, out->ptr) == 0 if (tls.HasRemainServer() &&
&& !(*out->ptr)->IsLogOff()) { tls.remain_server.id != s->server_list[tls.position].id) {
return 0; tls.ResetRemainServer();
} }
} for (uint32_t i = 0; i != tls.stride; ++i) {
return EHOSTDOWN; int64_t server_id = GetServerInNextStride(s->server_list, tls);
} if (!ExcludedServers::IsExcluded(in.excluded, server_id)
&& Socket::Address(server_id, out->ptr) == 0
int64_t WeightedRoundRobinLoadBalancer::GetBestServer( && !(*out->ptr)->IsLogOff()) {
const std::vector<std::pair<SocketId, int>>& server_list, TLS& tls) { return 0;
int64_t final_server = -1; }
int stride = tls.stride; }
int weight = 0; return EHOSTDOWN;
while (stride > 0) { }
if (tls.HasRemainServer()) {
weight = tls.remain_server.second; int64_t WeightedRoundRobinLoadBalancer::GetServerInNextStride(
if (weight <= stride) { const std::vector<Server>& server_list, TLS& tls) {
tls.ResetRemainServer(); int64_t final_server = -1;
} else { int stride = tls.stride;
tls.remain_server.second -= stride; if (tls.HasRemainServer()) {
} final_server = tls.remain_server.id;
} else { if (tls.remain_server.weight > stride) {
weight = server_list[tls.position].second; tls.remain_server.weight -= stride;
if (weight > stride) { return final_server;
tls.SetRemainServer(server_list[tls.position].first, } else {
weight - stride); stride -= tls.remain_server.weight;
} tls.ResetRemainServer();
tls.UpdatePosition(server_list.size()); tls.UpdatePosition(server_list.size());
} }
stride -= weight; }
} while (stride > 0) {
if (tls.HasRemainServer()) { final_server = server_list[tls.position].id;
final_server = tls.remain_server.first; if (server_list[tls.position].weight > stride) {
} else { tls.SetRemainServer(server_list[tls.position].id,
size_t index = tls.position == 0 ? server_list.size() - 1 server_list[tls.position].weight - stride);
: tls.position - 1; return final_server;
final_server = server_list[index].first; }
} stride -= server_list[tls.position].weight;
return final_server; tls.UpdatePosition(server_list.size());
} }
return final_server;
uint32_t WeightedRoundRobinLoadBalancer::GetStride( }
const uint32_t weight_sum, const uint32_t num) {
uint32_t average_weight = weight_sum / num; uint32_t WeightedRoundRobinLoadBalancer::GetStride(
// The stride is the first number which is greater than or equal to const uint32_t weight_sum, const uint32_t num) {
// average weight and coprime to weight_sum. uint32_t average_weight = weight_sum / num;
while (!IsCoprime(weight_sum, average_weight)) { auto iter = std::lower_bound(prime_stride.begin(), prime_stride.end(),
++average_weight; average_weight);
} while (iter != prime_stride.end()
return average_weight; && !IsCoprime(weight_sum, *iter)) {
} ++iter;
}
LoadBalancer* WeightedRoundRobinLoadBalancer::New() const { CHECK(iter != prime_stride.end()) << "Failed to get stride";
return new (std::nothrow) WeightedRoundRobinLoadBalancer; return *iter > weight_sum ? *iter % weight_sum : *iter;
} }
void WeightedRoundRobinLoadBalancer::Destroy() { LoadBalancer* WeightedRoundRobinLoadBalancer::New() const {
delete this; return new (std::nothrow) WeightedRoundRobinLoadBalancer;
} }
void WeightedRoundRobinLoadBalancer::Describe( void WeightedRoundRobinLoadBalancer::Destroy() {
std::ostream &os, const DescribeOptions& options) { delete this;
if (!options.verbose) { }
os << "wrr";
return; void WeightedRoundRobinLoadBalancer::Describe(
} std::ostream &os, const DescribeOptions& options) {
os << "WeightedRoundRobin{"; if (!options.verbose) {
butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s; os << "wrr";
if (_db_servers.Read(&s) != 0) { return;
os << "fail to read _db_servers"; }
} else { os << "WeightedRoundRobin{";
os << "n=" << s->server_list.size() << ':'; butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s;
for (const auto& server : s->server_list) { if (_db_servers.Read(&s) != 0) {
os << ' ' << server.first << '(' << server.second << ')'; os << "fail to read _db_servers";
} } else {
} os << "n=" << s->server_list.size() << ':';
os << '}'; for (const auto& server : s->server_list) {
} os << ' ' << server.id << '(' << server.weight << ')';
}
} // namespace policy }
} // namespace brpc os << '}';
}
} // namespace policy
} // namespace brpc
...@@ -39,27 +39,31 @@ public: ...@@ -39,27 +39,31 @@ public:
void Describe(std::ostream&, const DescribeOptions& options); void Describe(std::ostream&, const DescribeOptions& options);
private: private:
struct Server {
Server(SocketId s_id = 0, int s_w = 0): id(s_id), weight(s_w) {}
SocketId id;
int weight;
};
struct Servers { struct Servers {
// The value is configured weight for each server. // The value is configured weight for each server.
std::vector<std::pair<SocketId, int>> server_list; std::vector<Server> server_list;
// The value is the index of the server in "server_list". // The value is the index of the server in "server_list".
std::map<SocketId, size_t> server_map; std::map<SocketId, size_t> server_map;
uint32_t weight_sum = 0; uint32_t weight_sum = 0;
}; };
struct TLS { struct TLS {
TLS(): remain_server(0, 0) { }
uint32_t position = 0; uint32_t position = 0;
uint32_t stride = 0; uint32_t stride = 0;
std::pair<SocketId, int> remain_server; Server remain_server;
bool HasRemainServer() const { bool HasRemainServer() const {
return remain_server.second != 0; return remain_server.weight != 0;
} }
void SetRemainServer(const SocketId id, const int weight) { void SetRemainServer(const SocketId id, const int weight) {
remain_server.first = id; remain_server.id = id;
remain_server.second = weight; remain_server.weight = weight;
} }
void ResetRemainServer() { void ResetRemainServer() {
remain_server.second = 0; remain_server.weight = 0;
} }
void UpdatePosition(const uint32_t size) { void UpdatePosition(const uint32_t size) {
++position; ++position;
...@@ -84,9 +88,8 @@ private: ...@@ -84,9 +88,8 @@ private:
static bool Remove(Servers& bg, const ServerId& id); static bool Remove(Servers& bg, const ServerId& id);
static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers); static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers);
static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers); static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers);
static int64_t GetBestServer( static int64_t GetServerInNextStride(const std::vector<Server>& server_list,
const std::vector<std::pair<SocketId, int>>& server_list, TLS& tls);
TLS& tls);
// Get a reasonable stride according to weights configured of servers. // Get a reasonable stride according to weights configured of servers.
static uint32_t GetStride(const uint32_t weight_sum, const uint32_t num); static uint32_t GetStride(const uint32_t weight_sum, const uint32_t num);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment