Commit 7883472f authored by zhujiashun's avatar zhujiashun

Fix issues in GOAWAY impl and add replated ut

parent e67bd7f8
......@@ -325,9 +325,10 @@ H2Context::H2Context(Socket* socket, const Server* server)
: _socket(socket)
, _remote_window_left(H2Settings::DEFAULT_INITIAL_WINDOW_SIZE)
, _conn_state(H2_CONNECTION_UNINITIALIZED)
, _last_server_stream_id(-1)
, _last_client_stream_id(1)
, _goaway_stream_id(-1)
, _last_receive_stream_id(-1)
, _last_send_stream_id(1)
, _goaway_received(false)
, _goaway_sent(false)
, _deferred_window_update(0) {
// Stop printing the field which is useless for remote settings.
_remote_settings.connection_window_size = 0;
......@@ -340,9 +341,9 @@ H2Context::H2Context(Socket* socket, const Server* server)
_unack_local_settings.connection_window_size = FLAGS_h2_client_connection_window_size;
}
#if defined(UNIT_TEST)
// In ut, we hope _last_client_stream_id run out quickly to test the correctness
// In ut, we hope _last_send_stream_id run out quickly to test the correctness
// of creating new h2 socket. This value is 10,000 less than 0x7FFFFFFF.
_last_client_stream_id = 0x7fffd8ef;
_last_send_stream_id = 0x7fffd8ef;
#endif
}
......@@ -384,12 +385,11 @@ H2StreamContext* H2Context::RemoveStream(int stream_id) {
void H2Context::RemoveGoAwayStreams(
int goaway_stream_id, std::vector<H2StreamContext*>* out_streams) {
out_streams->clear();
if (goaway_stream_id == 0) { // quick path
StreamMap tmp;
{
std::unique_lock<butil::Mutex> mu(_stream_mutex);
_goaway_stream_id = goaway_stream_id;
_goaway_received = true;
_pending_streams.swap(tmp);
}
for (StreamMap::const_iterator it = tmp.begin(); it != tmp.end(); ++it) {
......@@ -397,7 +397,7 @@ void H2Context::RemoveGoAwayStreams(
}
} else {
std::unique_lock<butil::Mutex> mu(_stream_mutex);
_goaway_stream_id = goaway_stream_id;
_goaway_received = true;
for (StreamMap::const_iterator it = _pending_streams.begin();
it != _pending_streams.end(); ++it) {
if (it->first > goaway_stream_id) {
......@@ -421,7 +421,7 @@ H2StreamContext* H2Context::FindStream(int stream_id) {
int H2Context::TryToInsertStream(int stream_id, H2StreamContext* ctx) {
std::unique_lock<butil::Mutex> mu(_stream_mutex);
if (_goaway_stream_id >= 0 && stream_id > _goaway_stream_id) {
if (_goaway_received) {
return 1;
}
H2StreamContext*& sctx = _pending_streams[stream_id];
......@@ -524,14 +524,15 @@ ParseResult H2Context::Consume(
}
return MakeMessage(NULL);
} else { // send GOAWAY
char goawaybuf[FRAME_HEAD_SIZE + 4];
char goawaybuf[FRAME_HEAD_SIZE + 8];
SerializeFrameHead(goawaybuf, 8, H2_FRAME_GOAWAY, 0, 0);
SaveUint32(goawaybuf + FRAME_HEAD_SIZE, 0/*last-stream-id*/);
SaveUint32(goawaybuf + FRAME_HEAD_SIZE, _last_receive_stream_id);
SaveUint32(goawaybuf + FRAME_HEAD_SIZE + 4, h2_res.error());
if (WriteAck(_socket, goawaybuf, sizeof(goawaybuf)) != 0) {
LOG(WARNING) << "Fail to send GOAWAY to " << *_socket;
return MakeParseError(PARSE_ERROR_ABSOLUTELY_WRONG);
}
_goaway_sent = true;
return MakeMessage(NULL);
}
} else {
......@@ -548,6 +549,12 @@ H2ParseResult H2Context::OnHeaders(
LOG(ERROR) << "Invalid stream_id=" << frame_head.stream_id;
return MakeH2Error(H2_PROTOCOL_ERROR);
}
if (_goaway_sent) {
// TODO(zhujiashun): After sending a GOAWAY frame, the sender can discard
// frames for streams initiated by the receiver with identifiers higher
// than the identified last stream.
// Do we really need this strict check?
}
const bool has_padding = (frame_head.flags & H2_FLAGS_PADDED);
const bool has_priority = (frame_head.flags & H2_FLAGS_PRIORITY);
if (frame_head.payload_size <
......@@ -573,22 +580,13 @@ H2ParseResult H2Context::OnHeaders(
frag_size -= pad_length;
H2StreamContext* sctx = NULL;
if (is_server_side() &&
frame_head.stream_id > _last_server_stream_id) { // new stream
frame_head.stream_id > _last_receive_stream_id) { // new stream
if ((frame_head.stream_id & 1) == 0) {
LOG(ERROR) << "stream_id=" << frame_head.stream_id
<< " created by client is not odd";
return MakeH2Error(H2_PROTOCOL_ERROR);
}
if (((frame_head.stream_id - _last_server_stream_id) & 1) != 0) {
LOG(ERROR) << "Invalid stream_id=" << frame_head.stream_id;
return MakeH2Error(H2_PROTOCOL_ERROR);
}
if (VolatilePendingStreamSize() >= local_settings().max_concurrent_streams) {
LOG(ERROR) << "Reached max concurrent stream="
<< local_settings().max_concurrent_streams;
return MakeH2Error(H2_REFUSED_STREAM);
}
_last_server_stream_id = frame_head.stream_id;
_last_receive_stream_id = frame_head.stream_id;
sctx = new H2StreamContext(_socket->is_read_progressive());
sctx->Init(this, frame_head.stream_id);
const int rc = TryToInsertStream(frame_head.stream_id, sctx);
......@@ -939,13 +937,31 @@ static void* ProcessHttpResponseWrapper(void* void_arg) {
}
H2ParseResult H2Context::OnGoAway(
butil::IOBufBytesIterator&, const H2FrameHead& h) {
butil::IOBufBytesIterator& it, const H2FrameHead& h) {
if (h.payload_size < 8) {
LOG(ERROR) << "Invalid payload_size=" << h.payload_size;
return MakeH2Error(H2_FRAME_SIZE_ERROR);
}
if (h.stream_id != 0) {
LOG(ERROR) << "Invalid stream_id=" << h.stream_id;
return MakeH2Error(H2_PROTOCOL_ERROR);
}
if (h.flags) {
LOG(ERROR) << "Invalid flags=" << h.flags;
return MakeH2Error(H2_PROTOCOL_ERROR);
}
// Skip Additional Debug Data
it.forward(h.payload_size - 8);
const int last_stream_id = static_cast<int>(LoadUint32(it));
const H2Error ALLOW_UNUSED h2_error = static_cast<H2Error>(LoadUint32(it));
// TODO(zhujiashun): client and server should unify the code.
// Server Push is not supported so it works fine now.
if (is_client_side()) {
// The socket will not be selected for further requests.
_socket->SetLogOff();
std::vector<H2StreamContext*> goaway_streams;
RemoveGoAwayStreams(h.stream_id, &goaway_streams);
RemoveGoAwayStreams(last_stream_id, &goaway_streams);
if (goaway_streams.empty()) {
return MakeH2Message(NULL);
}
......@@ -1007,11 +1023,8 @@ void H2Context::Describe(std::ostream& os, const DescribeOptions& opt) const {
}
const char sep = (opt.verbose ? '\n' : ' ');
os << "conn_state=" << H2ConnectionState2Str(_conn_state);
if (is_server_side()) {
os << sep << "last_server_stream_id=" << _last_server_stream_id;
} else {
os << sep << "last_client_stream_id=" << _last_client_stream_id;
}
os << sep << "last_receive_stream_id=" << _last_receive_stream_id;
os << sep << "last_send_stream_id=" << _last_send_stream_id;
os << sep << "deferred_window_update="
<< _deferred_window_update.load(butil::memory_order_relaxed)
<< sep << "remote_conn_window_left="
......
......@@ -380,9 +380,10 @@ friend void InitFrameHandlers();
Socket* _socket;
butil::atomic<int64_t> _remote_window_left;
H2ConnectionState _conn_state;
int _last_server_stream_id;
uint32_t _last_client_stream_id;
int _goaway_stream_id;
int _last_receive_stream_id;
uint32_t _last_send_stream_id;
bool _goaway_received;
bool _goaway_sent;
H2Settings _remote_settings;
H2Settings _local_settings;
H2Settings _unack_local_settings;
......@@ -401,13 +402,13 @@ inline int H2Context::AllocateClientStreamId() {
<< _last_client_stream_id;
return -1;
}
const int id = _last_client_stream_id;
_last_client_stream_id += 2;
const int id = _last_send_stream_id;
_last_send_stream_id += 2;
return id;
}
inline bool H2Context::RunOutStreams() const {
return (_last_client_stream_id > 0x7FFFFFFF);
return (_last_send_stream_id > 0x7FFFFFFF);
}
inline std::ostream& operator<<(std::ostream& os, const H2UnsentRequest& req) {
......
......@@ -1324,4 +1324,43 @@ TEST_F(HttpTest, http2_header_after_data) {
ASSERT_EQ(*user_defined2, "b");
}
TEST_F(HttpTest, http2_goaway) {
brpc::Controller cntl;
// Prepare request
butil::IOBuf req_out;
int h2_stream_id = 0;
MakeH2EchoRequestBuf(&req_out, &cntl, &h2_stream_id);
// Prepare response
butil::IOBuf res_out;
MakeH2EchoResponseBuf(&res_out, h2_stream_id);
// append goaway
char goawaybuf[9 /*FRAME_HEAD_SIZE*/ + 8];
brpc::policy::SerializeFrameHead(goawaybuf, 8, brpc::policy::H2_FRAME_GOAWAY, 0, 0);
SaveUint32(goawaybuf + 9, 0x7fffd8ef /*last stream id*/);
SaveUint32(goawaybuf + 13, brpc::H2_NO_ERROR);
res_out.append(goawaybuf, sizeof(goawaybuf));
// parse response
brpc::ParseResult res_pr =
brpc::policy::ParseH2Message(&res_out, _h2_client_sock.get(), false, NULL);
ASSERT_TRUE(res_pr.is_ok());
// process response
ProcessMessage(brpc::policy::ProcessHttpResponse, res_pr.message(), false);
ASSERT_TRUE(!cntl.Failed());
// parse GOAWAY
res_pr = brpc::policy::ParseH2Message(&res_out, _h2_client_sock.get(), false, NULL);
ASSERT_EQ(res_pr.error(), brpc::PARSE_ERROR_NOT_ENOUGH_DATA);
// Since GOAWAY has been received, the next request should fail
brpc::policy::H2UnsentRequest* h2_req = brpc::policy::H2UnsentRequest::New(&cntl);
cntl._current_call.stream_user_data = h2_req;
brpc::SocketMessage* socket_message = NULL;
brpc::policy::PackH2Request(NULL, &socket_message, cntl.call_id().value,
NULL, &cntl, butil::IOBuf(), NULL);
butil::IOBuf dummy;
butil::Status st = socket_message->AppendAndDestroySelf(&dummy, _h2_client_sock.get());
ASSERT_EQ(st.error_code(), brpc::ELOGOFF);
ASSERT_TRUE(st.error_data().ends_with("the connection just issued GOAWAY"));
}
} //namespace
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment