Commit 3690d3ba authored by root's avatar root

little change for algorithm: each tls has the same stride, but a different…

little change for algorithm: each tls has the same stride, but a different random beginning position of server list.
parent fdc19fe2
// Copyright (c) 2018 Iqiyi, Inc. // Copyright (c) 2018 Iqiyi, Inc.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Authors: Daojin Cai (caidaojin@qiyi.com) // Authors: Daojin Cai (caidaojin@qiyi.com)
#include "butil/fast_rand.h" #include "butil/fast_rand.h"
#include "brpc/socket.h" #include "brpc/socket.h"
#include "brpc/policy/weighted_round_robin_load_balancer.h" #include "brpc/policy/weighted_round_robin_load_balancer.h"
#include "butil/strings/string_number_conversions.h" #include "butil/strings/string_number_conversions.h"
namespace brpc { namespace brpc {
namespace policy { namespace policy {
bool IsCoprime(uint32_t num1, uint32_t num2) { bool IsCoprime(uint32_t num1, uint32_t num2) {
uint32_t temp; uint32_t temp;
if (num1 < num2) { if (num1 < num2) {
temp = num1; temp = num1;
num1 = num2; num1 = num2;
num2 = temp; num2 = temp;
} }
while (true) { while (true) {
temp = num1 % num2; temp = num1 % num2;
if (temp == 0) { if (temp == 0) {
break; break;
} else { } else {
num1 = num2; num1 = num2;
num2 = temp; num2 = temp;
} }
} }
return num2 == 1; return num2 == 1;
} }
bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) { bool WeightedRoundRobinLoadBalancer::Add(Servers& bg, const ServerId& id) {
if (bg.server_list.capacity() < 128) { if (bg.server_list.capacity() < 128) {
bg.server_list.reserve(128); bg.server_list.reserve(128);
} }
int weight = 0; int weight = 0;
if (butil::StringToInt(id.tag, &weight) && weight > 0) { if (butil::StringToInt(id.tag, &weight) && weight > 0) {
bool insert_server = bool insert_server =
bg.server_map.emplace(id.id, bg.server_list.size()).second; bg.server_map.emplace(id.id, bg.server_list.size()).second;
if (insert_server) { if (insert_server) {
bg.server_list.emplace_back(id.id, weight); bg.server_list.emplace_back(id.id, weight);
bg.weight_sum += weight; bg.weight_sum += weight;
return true; return true;
} }
} else { } else {
LOG(ERROR) << "Invalid weight is set: " << id.tag; LOG(ERROR) << "Invalid weight is set: " << id.tag;
} }
return false; return false;
} }
bool WeightedRoundRobinLoadBalancer::Remove(Servers& bg, const ServerId& id) { bool WeightedRoundRobinLoadBalancer::Remove(Servers& bg, const ServerId& id) {
auto iter = bg.server_map.find(id.id); auto iter = bg.server_map.find(id.id);
if (iter != bg.server_map.end()) { if (iter != bg.server_map.end()) {
const size_t index = iter->second; const size_t index = iter->second;
bg.weight_sum -= bg.server_list[index].second; bg.weight_sum -= bg.server_list[index].second;
bg.server_list[index] = bg.server_list.back(); bg.server_list[index] = bg.server_list.back();
bg.server_map[bg.server_list[index].first] = index; bg.server_map[bg.server_list[index].first] = index;
bg.server_list.pop_back(); bg.server_list.pop_back();
bg.server_map.erase(iter); bg.server_map.erase(iter);
return true; return true;
} }
return false; return false;
} }
size_t WeightedRoundRobinLoadBalancer::BatchAdd( size_t WeightedRoundRobinLoadBalancer::BatchAdd(
Servers& bg, const std::vector<ServerId>& servers) { Servers& bg, const std::vector<ServerId>& servers) {
size_t count = 0; size_t count = 0;
for (size_t i = 0; i < servers.size(); ++i) { for (size_t i = 0; i < servers.size(); ++i) {
count += !!Add(bg, servers[i]); count += !!Add(bg, servers[i]);
} }
return count; return count;
} }
size_t WeightedRoundRobinLoadBalancer::BatchRemove( size_t WeightedRoundRobinLoadBalancer::BatchRemove(
Servers& bg, const std::vector<ServerId>& servers) { Servers& bg, const std::vector<ServerId>& servers) {
size_t count = 0; size_t count = 0;
for (size_t i = 0; i < servers.size(); ++i) { for (size_t i = 0; i < servers.size(); ++i) {
count += !!Remove(bg, servers[i]); count += !!Remove(bg, servers[i]);
} }
return count; return count;
} }
bool WeightedRoundRobinLoadBalancer::AddServer(const ServerId& id) { bool WeightedRoundRobinLoadBalancer::AddServer(const ServerId& id) {
return _db_servers.Modify(Add, id); return _db_servers.Modify(Add, id);
} }
bool WeightedRoundRobinLoadBalancer::RemoveServer(const ServerId& id) { bool WeightedRoundRobinLoadBalancer::RemoveServer(const ServerId& id) {
return _db_servers.Modify(Remove, id); return _db_servers.Modify(Remove, id);
} }
size_t WeightedRoundRobinLoadBalancer::AddServersInBatch( size_t WeightedRoundRobinLoadBalancer::AddServersInBatch(
const std::vector<ServerId>& servers) { const std::vector<ServerId>& servers) {
const size_t n = _db_servers.Modify(BatchAdd, servers); const size_t n = _db_servers.Modify(BatchAdd, servers);
LOG_IF(ERROR, n != servers.size()) LOG_IF(ERROR, n != servers.size())
<< "Fail to AddServersInBatch, expected " << servers.size() << "Fail to AddServersInBatch, expected " << servers.size()
<< " actually " << n; << " actually " << n;
return n; return n;
} }
size_t WeightedRoundRobinLoadBalancer::RemoveServersInBatch( size_t WeightedRoundRobinLoadBalancer::RemoveServersInBatch(
const std::vector<ServerId>& servers) { const std::vector<ServerId>& servers) {
const size_t n = _db_servers.Modify(BatchRemove, servers); const size_t n = _db_servers.Modify(BatchRemove, servers);
LOG_IF(ERROR, n != servers.size()) LOG_IF(ERROR, n != servers.size())
<< "Fail to RemoveServersInBatch, expected " << servers.size() << "Fail to RemoveServersInBatch, expected " << servers.size()
<< " actually " << n; << " actually " << n;
return n; return n;
} }
int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* out) { int WeightedRoundRobinLoadBalancer::SelectServer(const SelectIn& in, SelectOut* out) {
butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s; butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s;
if (_db_servers.Read(&s) != 0) { if (_db_servers.Read(&s) != 0) {
return ENOMEM; return ENOMEM;
} }
if (s->server_list.empty()) { if (s->server_list.empty()) {
return ENODATA; return ENODATA;
} }
TLS& tls = s.tls(); TLS& tls = s.tls();
if (tls.IsNeededCaculateNewStride(s->weight_sum, s->server_list.size())) { if (tls.IsNeededCaculateNewStride(s->weight_sum, s->server_list.size())) {
tls.stride = GetStride(s->weight_sum, s->server_list.size()); if (tls.stride == 0) {
tls.offset = butil::fast_rand_less_than(tls.stride); tls.position = butil::fast_rand_less_than(s->server_list.size());
} }
// If server list changed, the position may be out of range. tls.stride = GetStride(s->weight_sum, s->server_list.size());
tls.position %= s->server_list.size(); }
// Check whether remain server was removed from server list. // If server list changed, the position may be out of range.
if (tls.HasRemainServer() && tls.position %= s->server_list.size();
s->server_map.find(tls.remain_server.first) == s->server_map.end()) { // Check whether remain server was removed from server list.
tls.ResetRemainServer(); if (tls.HasRemainServer() &&
} s->server_map.find(tls.remain_server.first) == s->server_map.end()) {
for ( uint32_t i = 0; i != tls.stride; ++i) { tls.ResetRemainServer();
int64_t best = GetBestServer(s->server_list, tls); }
if (!ExcludedServers::IsExcluded(in.excluded, best) for ( uint32_t i = 0; i != tls.stride; ++i) {
&& Socket::Address(best, out->ptr) == 0 int64_t best = GetBestServer(s->server_list, tls);
&& !(*out->ptr)->IsLogOff()) { if (!ExcludedServers::IsExcluded(in.excluded, best)
return 0; && Socket::Address(best, out->ptr) == 0
} && !(*out->ptr)->IsLogOff()) {
} return 0;
return EHOSTDOWN; }
} }
return EHOSTDOWN;
int64_t WeightedRoundRobinLoadBalancer::GetBestServer( }
const std::vector<std::pair<SocketId, int>>& server_list,
TLS& tls) { int64_t WeightedRoundRobinLoadBalancer::GetBestServer(
uint32_t comp_weight = 0; const std::vector<std::pair<SocketId, int>>& server_list, TLS& tls) {
int64_t final_server = -1; int64_t final_server = -1;
int stride = tls.stride; int stride = tls.stride;
int weight = 0; int weight = 0;
while (stride > 0) { while (stride > 0) {
if (tls.HasRemainServer()) { if (tls.HasRemainServer()) {
weight = tls.remain_server.second; weight = tls.remain_server.second;
if (weight <= stride) { if (weight <= stride) {
TryToGetFinalServer(tls, tls.remain_server, tls.ResetRemainServer();
comp_weight, &final_server); } else {
tls.ResetRemainServer(); tls.remain_server.second -= stride;
} else { }
TryToGetFinalServer(tls, } else {
std::pair<SocketId, int>(tls.remain_server.first, stride), weight = server_list[tls.position].second;
comp_weight, &final_server); if (weight > stride) {
tls.remain_server.second -= stride; tls.SetRemainServer(server_list[tls.position].first,
} weight - stride);
} else { }
weight = server_list[tls.position].second; tls.UpdatePosition(server_list.size());
if (weight <= stride) { }
TryToGetFinalServer(tls, server_list[tls.position], stride -= weight;
comp_weight, &final_server); }
} else { if (tls.HasRemainServer()) {
TryToGetFinalServer(tls, final_server = tls.remain_server.first;
std::pair<SocketId, int>( } else {
server_list[tls.position].first, stride), final_server = tls.position == 0 ? server_list.size() -1
comp_weight, &final_server); : tls.position -1;
tls.SetRemainServer(server_list[tls.position].first, }
weight - stride); return final_server;
} }
tls.UpdatePosition(server_list.size());
} uint32_t WeightedRoundRobinLoadBalancer::GetStride(
stride -= weight; const uint32_t weight_sum, const uint32_t num) {
} uint32_t average_weight = weight_sum / num;
return final_server; // The stride is the first number which is greater than or equal to
} // average weight and coprime to weight_sum.
while (!IsCoprime(weight_sum, average_weight)) {
uint32_t WeightedRoundRobinLoadBalancer::GetStride( ++average_weight;
const uint32_t weight_sum, const uint32_t num) { }
uint32_t average_weight = weight_sum / num; return average_weight;
// The stride is the first number which is greater than or equal to }
// average weight and coprime to weight_sum.
while (!IsCoprime(weight_sum, average_weight)) { LoadBalancer* WeightedRoundRobinLoadBalancer::New() const {
++average_weight; return new (std::nothrow) WeightedRoundRobinLoadBalancer;
} }
return average_weight;
} void WeightedRoundRobinLoadBalancer::Destroy() {
delete this;
void WeightedRoundRobinLoadBalancer::TryToGetFinalServer( }
const TLS& tls, const std::pair<SocketId, int> server,
uint32_t& comp_weight, int64_t* final_server) { void WeightedRoundRobinLoadBalancer::Describe(
if (*final_server == -1) { std::ostream &os, const DescribeOptions& options) {
comp_weight += server.second; if (!options.verbose) {
if (comp_weight >= tls.offset) { os << "wrr";
*final_server = server.first; return;
} }
} os << "WeightedRoundRobin{";
} butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s;
if (_db_servers.Read(&s) != 0) {
LoadBalancer* WeightedRoundRobinLoadBalancer::New() const { os << "fail to read _db_servers";
return new (std::nothrow) WeightedRoundRobinLoadBalancer; } else {
} os << "n=" << s->server_list.size() << ':';
for (const auto& server : s->server_list) {
void WeightedRoundRobinLoadBalancer::Destroy() { os << ' ' << server.first << '(' << server.second << ')';
delete this; }
} }
os << '}';
void WeightedRoundRobinLoadBalancer::Describe( }
std::ostream &os, const DescribeOptions& options) {
if (!options.verbose) { } // namespace policy
os << "wrr"; } // namespace brpc
return;
}
os << "WeightedRoundRobin{";
butil::DoublyBufferedData<Servers, TLS>::ScopedPtr s;
if (_db_servers.Read(&s) != 0) {
os << "fail to read _db_servers";
} else {
os << "n=" << s->server_list.size() << ':';
for (const auto& server : s->server_list) {
os << ' ' << server.first << '(' << server.second << ')';
}
}
os << '}';
}
} // namespace policy
} // namespace brpc
// Copyright (c) 2018 Iqiyi, Inc. // Copyright (c) 2018 Iqiyi, Inc.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Authors: Daojin Cai (caidaojin@qiyi.com) // Authors: Daojin Cai (caidaojin@qiyi.com)
#ifndef BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H #ifndef BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H
#define BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H #define BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H
#include <map> #include <map>
#include <vector> #include <vector>
#include "butil/containers/doubly_buffered_data.h" #include "butil/containers/doubly_buffered_data.h"
#include "brpc/load_balancer.h" #include "brpc/load_balancer.h"
namespace brpc { namespace brpc {
namespace policy { namespace policy {
// This LoadBalancer selects server as the assigned weight. // This LoadBalancer selects server as the assigned weight.
// Weight is got from tag of ServerId. // Weight is got from tag of ServerId.
class WeightedRoundRobinLoadBalancer : public LoadBalancer { class WeightedRoundRobinLoadBalancer : public LoadBalancer {
public: public:
bool AddServer(const ServerId& id); bool AddServer(const ServerId& id);
bool RemoveServer(const ServerId& id); bool RemoveServer(const ServerId& id);
size_t AddServersInBatch(const std::vector<ServerId>& servers); size_t AddServersInBatch(const std::vector<ServerId>& servers);
size_t RemoveServersInBatch(const std::vector<ServerId>& servers); size_t RemoveServersInBatch(const std::vector<ServerId>& servers);
int SelectServer(const SelectIn& in, SelectOut* out); int SelectServer(const SelectIn& in, SelectOut* out);
LoadBalancer* New() const; LoadBalancer* New() const;
void Destroy(); void Destroy();
void Describe(std::ostream&, const DescribeOptions& options); void Describe(std::ostream&, const DescribeOptions& options);
private: private:
struct Servers { struct Servers {
// The value is configured weight for each server. // The value is configured weight for each server.
std::vector<std::pair<SocketId, int>> server_list; std::vector<std::pair<SocketId, int>> server_list;
// The value is the index of the server in "server_list". // The value is the index of the server in "server_list".
std::map<SocketId, size_t> server_map; std::map<SocketId, size_t> server_map;
uint32_t weight_sum = 0; uint32_t weight_sum = 0;
}; };
struct TLS { struct TLS {
TLS(): remain_server(0, 0) { } TLS(): remain_server(0, 0) { }
uint32_t position = 0; uint32_t position = 0;
uint32_t stride = 0; uint32_t stride = 0;
uint32_t offset = 0; std::pair<SocketId, int> remain_server;
std::pair<SocketId, int> remain_server; bool HasRemainServer() const {
bool HasRemainServer() const { return remain_server.second != 0;
return remain_server.second != 0; }
} void SetRemainServer(const SocketId id, const int weight) {
void SetRemainServer(const SocketId id, const int weight) { remain_server.first = id;
remain_server.first = id; remain_server.second = weight;
remain_server.second = weight; }
} void ResetRemainServer() {
void ResetRemainServer() { remain_server.second = 0;
remain_server.second = 0; }
} void UpdatePosition(const uint32_t size) {
void UpdatePosition(const uint32_t size) { ++position;
++position; position %= size;
position %= size; }
} // If server list changed, we need caculate a new stride.
// If server list changed, we need caculate a new stride. bool IsNeededCaculateNewStride(const uint32_t curr_weight_sum,
bool IsNeededCaculateNewStride(const uint32_t curr_weight_sum, const uint32_t curr_servers_num) {
const uint32_t curr_servers_num) { if (curr_weight_sum != weight_sum
if (curr_weight_sum != weight_sum || curr_servers_num != servers_num) {
|| curr_servers_num != servers_num) { weight_sum = curr_weight_sum;
weight_sum = curr_weight_sum; servers_num = curr_servers_num;
servers_num = curr_servers_num; return true;
return true; }
} return false;
return false; }
} private:
private: uint32_t weight_sum = 0;
uint32_t weight_sum = 0; uint32_t servers_num = 0;
uint32_t servers_num = 0; };
}; static bool Add(Servers& bg, const ServerId& id);
static bool Add(Servers& bg, const ServerId& id); static bool Remove(Servers& bg, const ServerId& id);
static bool Remove(Servers& bg, const ServerId& id); static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers);
static size_t BatchAdd(Servers& bg, const std::vector<ServerId>& servers); static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers);
static size_t BatchRemove(Servers& bg, const std::vector<ServerId>& servers); static int64_t GetBestServer(
static int64_t GetBestServer( const std::vector<std::pair<SocketId, int>>& server_list,
const std::vector<std::pair<SocketId, int>>& server_list, TLS& tls);
TLS& tls); // Get a reasonable stride according to weights configured of servers.
// Get a reasonable stride according to weights configured of servers. static uint32_t GetStride(const uint32_t weight_sum, const uint32_t num);
static uint32_t GetStride(const uint32_t weight_sum, const uint32_t num); static void TryToGetFinalServer(const TLS& tls,
static void TryToGetFinalServer(const TLS& tls, const std::pair<SocketId, int> server,
const std::pair<SocketId, int> server, uint32_t& comp_weight, int64_t* final_server);
uint32_t& comp_weight, int64_t* final_server);
butil::DoublyBufferedData<Servers, TLS> _db_servers;
butil::DoublyBufferedData<Servers, TLS> _db_servers; };
};
} // namespace policy
} // namespace policy } // namespace brpc
} // namespace brpc
#endif // BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H
#endif // BRPC_POLICY_WEIGHTED_ROUND_ROBIN_LOAD_BALANCER_H
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment