Unverified Commit ede2bf59 authored by Ge Jun's avatar Ge Jun Committed by GitHub

Merge pull request #523 from brpc/http2_goaway

Fix issues in GOAWAY impl and add related ut
parents d6d86cf0 585c081b
......@@ -325,8 +325,8 @@ H2Context::H2Context(Socket* socket, const Server* server)
: _socket(socket)
, _remote_window_left(H2Settings::DEFAULT_INITIAL_WINDOW_SIZE)
, _conn_state(H2_CONNECTION_UNINITIALIZED)
, _last_server_stream_id(-1)
, _last_client_stream_id(1)
, _last_received_stream_id(-1)
, _last_sent_stream_id(1)
, _goaway_stream_id(-1)
, _deferred_window_update(0) {
// Stop printing the field which is useless for remote settings.
......@@ -340,9 +340,9 @@ H2Context::H2Context(Socket* socket, const Server* server)
_unack_local_settings.connection_window_size = FLAGS_h2_client_connection_window_size;
}
#if defined(UNIT_TEST)
// In ut, we hope _last_client_stream_id run out quickly to test the correctness
// In ut, we hope _last_sent_stream_id run out quickly to test the correctness
// of creating new h2 socket. This value is 10,000 less than 0x7FFFFFFF.
_last_client_stream_id = 0x7fffd8ef;
_last_sent_stream_id = 0x7fffd8ef;
#endif
}
......@@ -384,7 +384,6 @@ H2StreamContext* H2Context::RemoveStream(int stream_id) {
void H2Context::RemoveGoAwayStreams(
int goaway_stream_id, std::vector<H2StreamContext*>* out_streams) {
out_streams->clear();
if (goaway_stream_id == 0) { // quick path
StreamMap tmp;
{
......@@ -524,9 +523,9 @@ ParseResult H2Context::Consume(
}
return MakeMessage(NULL);
} else { // send GOAWAY
char goawaybuf[FRAME_HEAD_SIZE + 4];
char goawaybuf[FRAME_HEAD_SIZE + 8];
SerializeFrameHead(goawaybuf, 8, H2_FRAME_GOAWAY, 0, 0);
SaveUint32(goawaybuf + FRAME_HEAD_SIZE, 0/*last-stream-id*/);
SaveUint32(goawaybuf + FRAME_HEAD_SIZE, _last_received_stream_id);
SaveUint32(goawaybuf + FRAME_HEAD_SIZE + 4, h2_res.error());
if (WriteAck(_socket, goawaybuf, sizeof(goawaybuf)) != 0) {
LOG(WARNING) << "Fail to send GOAWAY to " << *_socket;
......@@ -573,22 +572,13 @@ H2ParseResult H2Context::OnHeaders(
frag_size -= pad_length;
H2StreamContext* sctx = NULL;
if (is_server_side() &&
frame_head.stream_id > _last_server_stream_id) { // new stream
frame_head.stream_id > _last_received_stream_id) { // new stream
if ((frame_head.stream_id & 1) == 0) {
LOG(ERROR) << "stream_id=" << frame_head.stream_id
<< " created by client is not odd";
return MakeH2Error(H2_PROTOCOL_ERROR);
}
if (((frame_head.stream_id - _last_server_stream_id) & 1) != 0) {
LOG(ERROR) << "Invalid stream_id=" << frame_head.stream_id;
return MakeH2Error(H2_PROTOCOL_ERROR);
}
if (VolatilePendingStreamSize() >= local_settings().max_concurrent_streams) {
LOG(ERROR) << "Reached max concurrent stream="
<< local_settings().max_concurrent_streams;
return MakeH2Error(H2_REFUSED_STREAM);
}
_last_server_stream_id = frame_head.stream_id;
_last_received_stream_id = frame_head.stream_id;
sctx = new H2StreamContext(_socket->is_read_progressive());
sctx->Init(this, frame_head.stream_id);
const int rc = TryToInsertStream(frame_head.stream_id, sctx);
......@@ -939,13 +929,31 @@ static void* ProcessHttpResponseWrapper(void* void_arg) {
}
H2ParseResult H2Context::OnGoAway(
butil::IOBufBytesIterator&, const H2FrameHead& h) {
butil::IOBufBytesIterator& it, const H2FrameHead& h) {
if (h.payload_size < 8) {
LOG(ERROR) << "Invalid payload_size=" << h.payload_size;
return MakeH2Error(H2_FRAME_SIZE_ERROR);
}
if (h.stream_id != 0) {
LOG(ERROR) << "Invalid stream_id=" << h.stream_id;
return MakeH2Error(H2_PROTOCOL_ERROR);
}
if (h.flags) {
LOG(ERROR) << "Invalid flags=" << h.flags;
return MakeH2Error(H2_PROTOCOL_ERROR);
}
// Skip Additional Debug Data
it.forward(h.payload_size - 8);
const int last_stream_id = static_cast<int>(LoadUint32(it));
const H2Error ALLOW_UNUSED h2_error = static_cast<H2Error>(LoadUint32(it));
// TODO(zhujiashun): client and server should unify the code.
// Server Push is not supported so it works fine now.
if (is_client_side()) {
// The socket will not be selected for further requests.
_socket->SetLogOff();
std::vector<H2StreamContext*> goaway_streams;
RemoveGoAwayStreams(h.stream_id, &goaway_streams);
RemoveGoAwayStreams(last_stream_id, &goaway_streams);
if (goaway_streams.empty()) {
return MakeH2Message(NULL);
}
......@@ -1007,11 +1015,8 @@ void H2Context::Describe(std::ostream& os, const DescribeOptions& opt) const {
}
const char sep = (opt.verbose ? '\n' : ' ');
os << "conn_state=" << H2ConnectionState2Str(_conn_state);
if (is_server_side()) {
os << sep << "last_server_stream_id=" << _last_server_stream_id;
} else {
os << sep << "last_client_stream_id=" << _last_client_stream_id;
}
os << sep << "last_received_stream_id=" << _last_received_stream_id
<< sep << "last_sent_stream_id=" << _last_sent_stream_id;
os << sep << "deferred_window_update="
<< _deferred_window_update.load(butil::memory_order_relaxed)
<< sep << "remote_conn_window_left="
......
......@@ -380,8 +380,8 @@ friend void InitFrameHandlers();
Socket* _socket;
butil::atomic<int64_t> _remote_window_left;
H2ConnectionState _conn_state;
int _last_server_stream_id;
uint32_t _last_client_stream_id;
int _last_received_stream_id;
uint32_t _last_sent_stream_id;
int _goaway_stream_id;
H2Settings _remote_settings;
H2Settings _local_settings;
......@@ -397,17 +397,17 @@ friend void InitFrameHandlers();
inline int H2Context::AllocateClientStreamId() {
if (RunOutStreams()) {
LOG(WARNING) << "Fail to allocate new client stream, _last_client_stream_id="
<< _last_client_stream_id;
LOG(WARNING) << "Fail to allocate new client stream, _last_sent_stream_id="
<< _last_sent_stream_id;
return -1;
}
const int id = _last_client_stream_id;
_last_client_stream_id += 2;
const int id = _last_sent_stream_id;
_last_sent_stream_id += 2;
return id;
}
inline bool H2Context::RunOutStreams() const {
return (_last_client_stream_id > 0x7FFFFFFF);
return (_last_sent_stream_id > 0x7FFFFFFF);
}
inline std::ostream& operator<<(std::ostream& os, const H2UnsentRequest& req) {
......
......@@ -33,7 +33,7 @@ TEST(H2UnsentMessage, request_throughput) {
new brpc::policy::H2Context(h2_client_sock.get(), NULL);
CHECK_EQ(ctx->Init(), 0);
h2_client_sock->initialize_parsing_context(&ctx);
ctx->_last_client_stream_id = 0;
ctx->_last_sent_stream_id = 0;
ctx->_remote_window_left = brpc::H2Settings::MAX_WINDOW_SIZE;
int64_t ntotal = 500000;
......
......@@ -1324,4 +1324,43 @@ TEST_F(HttpTest, http2_header_after_data) {
ASSERT_EQ(*user_defined2, "b");
}
TEST_F(HttpTest, http2_goaway) {
brpc::Controller cntl;
// Prepare request
butil::IOBuf req_out;
int h2_stream_id = 0;
MakeH2EchoRequestBuf(&req_out, &cntl, &h2_stream_id);
// Prepare response
butil::IOBuf res_out;
MakeH2EchoResponseBuf(&res_out, h2_stream_id);
// append goaway
char goawaybuf[9 /*FRAME_HEAD_SIZE*/ + 8];
brpc::policy::SerializeFrameHead(goawaybuf, 8, brpc::policy::H2_FRAME_GOAWAY, 0, 0);
SaveUint32(goawaybuf + 9, 0x7fffd8ef /*last stream id*/);
SaveUint32(goawaybuf + 13, brpc::H2_NO_ERROR);
res_out.append(goawaybuf, sizeof(goawaybuf));
// parse response
brpc::ParseResult res_pr =
brpc::policy::ParseH2Message(&res_out, _h2_client_sock.get(), false, NULL);
ASSERT_TRUE(res_pr.is_ok());
// process response
ProcessMessage(brpc::policy::ProcessHttpResponse, res_pr.message(), false);
ASSERT_TRUE(!cntl.Failed());
// parse GOAWAY
res_pr = brpc::policy::ParseH2Message(&res_out, _h2_client_sock.get(), false, NULL);
ASSERT_EQ(res_pr.error(), brpc::PARSE_ERROR_NOT_ENOUGH_DATA);
// Since GOAWAY has been received, the next request should fail
brpc::policy::H2UnsentRequest* h2_req = brpc::policy::H2UnsentRequest::New(&cntl);
cntl._current_call.stream_user_data = h2_req;
brpc::SocketMessage* socket_message = NULL;
brpc::policy::PackH2Request(NULL, &socket_message, cntl.call_id().value,
NULL, &cntl, butil::IOBuf(), NULL);
butil::IOBuf dummy;
butil::Status st = socket_message->AppendAndDestroySelf(&dummy, _h2_client_sock.get());
ASSERT_EQ(st.error_code(), brpc::ELOGOFF);
ASSERT_TRUE(st.error_data().ends_with("the connection just issued GOAWAY"));
}
} //namespace
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment