// Copyright (c) 2018 Kenton Varda and contributors // Licensed under the MIT License: // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
#include "json-rpc.h"
#include <kj/compat/http.h>
#include <capnp/compat/json-rpc.capnp.h>

namespace capnp {

// IDs of the method annotations this file recognizes: one overrides the JSON-RPC method name,
// the other marks a method as a notification (sent without an "id"; no response is awaited).
// NOTE(review): these appear to be the `name` / `notification` annotations from
// compat/json.capnp -- keep in sync if those IDs ever change.
static constexpr uint64_t JSON_NAME_ANNOTATION_ID = 0xfa5b1fd61c2e7c3dull;
static constexpr uint64_t JSON_NOTIFICATION_ANNOTATION_ID = 0xa0a054dea32fd98cull;

// Converts the numeric `id` of an incoming response back into the call ID that
// CapabilityImpl::call() assigned from `callCount`. Returns null when the number cannot be such
// an ID. The value comes straight from the peer, and converting a negative, NaN, or
// out-of-range double to an unsigned integer is undefined behavior, so it is range-checked
// before the cast. Fractional values are rejected as well, since this side never issues them.
static kj::Maybe<uint> toCallId(double value) {
  // Written as a negated conjunction so that NaN (for which every comparison is false) fails.
  if (!(value >= 0 && value <= static_cast<double>(~0u))) {
    return nullptr;
  }
  uint result = static_cast<uint>(value);
  if (static_cast<double>(result) != value) {
    return nullptr;
  }
  return result;
}

// Stub for the remote peer: every call made on a capability returned by getPeer() is serialized
// as a JSON-RPC request (or notification) and written to the transport.
class JsonRpc::CapabilityImpl final: public DynamicCapability::Server {
public:
  CapabilityImpl(JsonRpc& parent, InterfaceSchema schema)
      : DynamicCapability::Server(schema), parent(parent) {}

  kj::Promise<void> call(InterfaceSchema::Method method,
                         CallContext<DynamicStruct, DynamicStruct> context) override {
    auto proto = method.getProto();

    // Resolve the wire-level method name and whether a response is expected, honoring the
    // method's annotations.
    bool isNotification = false;
    kj::StringPtr name = proto.getName();
    for (auto annotation: proto.getAnnotations()) {
      switch (annotation.getId()) {
        case JSON_NAME_ANNOTATION_ID:
          name = annotation.getValue().getText();
          break;
        case JSON_NOTIFICATION_ANNOTATION_ID:
          isNotification = true;
          break;
      }
    }

    // Build {"jsonrpc": "2.0", ["id": N,] "method": name, "params": {...}}. Notifications carry
    // no "id" member, hence one fewer object field.
    capnp::MallocMessageBuilder message;
    auto value = message.getRoot<json::Value>();
    auto list = value.initObject(3 + !isNotification);

    uint index = 0;

    auto jsonrpc = list[index++];
    jsonrpc.setName("jsonrpc");
    jsonrpc.initValue().setString("2.0");

    // Note: a call ID is consumed even for notifications.
    uint callId = parent.callCount++;

    if (!isNotification) {
      auto id = list[index++];
      id.setName("id");
      id.initValue().setNumber(callId);
    }

    auto methodName = list[index++];
    methodName.setName("method");
    methodName.initValue().setString(name);

    auto params = list[index++];
    params.setName("params");
    parent.codec.encode(context.getParams(), params.initValue());

    auto writePromise = parent.queueWrite(parent.codec.encode(value));

    if (isNotification) {
      // No response will ever arrive, so initialize an empty results struct for the caller and
      // complete as soon as the message has been written.
      auto sproto = context.getResultsType().getProto().getStruct();
      MessageSize size { sproto.getDataWordCount(), sproto.getPointerCount() };
      context.initResults(size);
      return kj::mv(writePromise);
    } else {
      // Register the call so readLoop() can route the matching response (by id) back here. The
      // returned promise completes only after the write finishes AND the response is handled.
      auto paf = kj::newPromiseAndFulfiller<void>();
      parent.awaitedResponses.insert(callId, AwaitedResponse { context, kj::mv(paf.fulfiller) });
      auto promise = writePromise.then([p = kj::mv(paf.promise)]() mutable { return kj::mv(p); });
      // Drop the awaitedResponses entry when the promise is destroyed (completed or canceled).
      auto& parentRef = parent;
      return promise.attach(kj::defer([&parentRef,callId]() {
        parentRef.awaitedResponses.erase(callId);
      }));
    }
  }

private:
  JsonRpc& parent;
};

JsonRpc::JsonRpc(Transport& transport, DynamicCapability::Client interface)
    : JsonRpc(transport, kj::mv(interface), kj::newPromiseAndFulfiller<void>()) {}

// Delegation target: `paf` backs errorPromise/errorFulfiller, which are rejected when the read
// loop, a write, or a background task fails.
JsonRpc::JsonRpc(Transport& transport, DynamicCapability::Client interfaceParam,
                 kj::PromiseFulfillerPair<void> paf)
    : transport(transport),
      interface(kj::mv(interfaceParam)),
      errorPromise(paf.promise.fork()),
      errorFulfiller(kj::mv(paf.fulfiller)),
      readTask(readLoop().eagerlyEvaluate([this](kj::Exception&& e) {
        errorFulfiller->reject(kj::mv(e));
      })),
      tasks(*this) {
  codec.handleByAnnotation(interface.getSchema());
  codec.handleByAnnotation<json::RpcMessage>();

  // Index the local interface's methods by their JSON-RPC name (the name annotation if present,
  // otherwise the Cap'n Proto name) so readLoop() can dispatch incoming calls by that name.
  for (auto method: interface.getSchema().getMethods()) {
    auto proto = method.getProto();
    kj::StringPtr name = proto.getName();
    for (auto annotation: proto.getAnnotations()) {
      switch (annotation.getId()) {
        case JSON_NAME_ANNOTATION_ID:
          name = annotation.getValue().getText();
          break;
      }
    }
    methodMap.insert(name, method);
  }
}

// Returns a capability whose calls are sent to the remote peer as JSON-RPC messages.
DynamicCapability::Client JsonRpc::getPeer(InterfaceSchema schema) {
  // Register annotation-driven JSON handling for the *peer's* interface, whose params/results
  // CapabilityImpl encodes and decodes. (The local interface was registered by the constructor.)
  codec.handleByAnnotation(schema);
  return kj::heap<CapabilityImpl>(*this, schema);
}

// Header table shared by every ContentLengthTransport's HTTP input stream.
static kj::HttpHeaderTable& staticHeaderTable() {
  static kj::HttpHeaderTable HEADER_TABLE;
  return HEADER_TABLE;
}

// Appends `text` to the write queue so messages are sent one at a time, in order. A failed send
// rejects errorFulfiller. Returns a promise that resolves once this particular message has been
// sent; callers may discard it because it is just one branch of a fork.
kj::Promise<void> JsonRpc::queueWrite(kj::String text) {
  auto fork = writeQueue.then([this, text = kj::mv(text)]() mutable {
    auto promise = transport.send(text);
    // Keep the text alive until the transport has finished with it.
    return promise.attach(kj::mv(text));
  }).eagerlyEvaluate([this](kj::Exception&& e) {
    errorFulfiller->reject(kj::mv(e));
  }).fork();
  writeQueue = fork.addBranch();
  return fork.addBranch();
}

// Queues a JSON-RPC error response with the given code and message. When `id` is null the
// response carries "id": null.
void JsonRpc::queueError(kj::Maybe<json::Value::Reader> id, int code, kj::StringPtr message) {
  MallocMessageBuilder capnpMessage;
  auto jsonResponse = capnpMessage.getRoot<json::RpcMessage>();
  jsonResponse.setJsonrpc("2.0");
  KJ_IF_MAYBE(i, id) {
    jsonResponse.setId(*i);
  } else {
    jsonResponse.initId().setNull();
  }
  auto error = jsonResponse.initError();
  error.setCode(code);
  error.setMessage(message);

  // OK to discard result of queueWrite() since it's just one branch of a fork.
  queueWrite(codec.encode(jsonResponse));
}

// Receives one message from the transport, handles it (incoming call, response to one of our
// calls, or error report), then recurses to receive the next one.
kj::Promise<void> JsonRpc::readLoop() {
  return transport.receive().then([this](kj::String message) -> kj::Promise<void> {
    MallocMessageBuilder capnpMessage;
    auto rpcMessageBuilder = capnpMessage.getRoot<json::RpcMessage>();

    KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
      codec.decode(message, rpcMessageBuilder);
    })) {
      // -32700: Parse error (the text could not be decoded at all).
      queueError(nullptr, -32700, kj::str("Parse error: ", exception->getDescription()));
      return readLoop();
    }

    KJ_CONTEXT("decoding JSON-RPC message", message);

    auto rpcMessage = rpcMessageBuilder.asReader();

    // The JSON decoded, but it is not a valid JSON-RPC 2.0 object: -32600 Invalid Request.
    if (!rpcMessage.hasJsonrpc()) {
      queueError(nullptr, -32600, kj::str("Missing 'jsonrpc' field."));
      return readLoop();
    } else if (rpcMessage.getJsonrpc() != "2.0") {
      queueError(nullptr, -32600,
          kj::str("Unknown JSON-RPC version. This peer implements version '2.0'."));
      return readLoop();
    }

    switch (rpcMessage.which()) {
      case json::RpcMessage::NONE:
        queueError(nullptr, -32600, kj::str("message has none of params, result, or error"));
        break;

      case json::RpcMessage::PARAMS: {
        // An incoming call. Dispatch by JSON-RPC name through methodMap, which honors the name
        // annotation; a lookup by Cap'n Proto name would miss renamed methods.
        kj::StringPtr requestedMethod = rpcMessage.getMethod();
        KJ_IF_MAYBE(method, methodMap.find(requestedMethod)) {
          auto req = interface.newRequest(*method);

          KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
            codec.decode(rpcMessage.getParams(), req);
          })) {
            kj::Maybe<JsonValue::Reader> id;
            if (rpcMessage.hasId()) id = rpcMessage.getId();
            // -32602: Invalid params.
            queueError(id, -32602,
                kj::str("Type error in method params: ", exception->getDescription()));
            break;
          }

          if (rpcMessage.hasId()) {
            // The response is built asynchronously, after `capnpMessage` is gone, so copy the
            // id into a standalone buffer that is attached to the response promise below.
            // (+1 word, presumably for the root pointer written by copyToUnchecked().)
            auto id = rpcMessage.getId();
            auto idCopy = kj::heapArray<word>(id.totalSize().wordCount + 1);
            memset(idCopy.begin(), 0, idCopy.asBytes().size());
            copyToUnchecked(id, idCopy);
            auto idPtr = readMessageUnchecked<json::Value>(idCopy.begin());

            auto promise = req.send()
                .then([this,idPtr](Response<DynamicStruct> response) mutable {
              MallocMessageBuilder capnpMessage;
              auto jsonResponse = capnpMessage.getRoot<json::RpcMessage>();
              jsonResponse.setJsonrpc("2.0");
              jsonResponse.setId(idPtr);
              codec.encode(DynamicStruct::Reader(response), jsonResponse.initResult());
              return queueWrite(codec.encode(jsonResponse));
            }, [this,idPtr](kj::Exception&& e) {
              MallocMessageBuilder capnpMessage;
              auto jsonResponse = capnpMessage.getRoot<json::RpcMessage>();
              jsonResponse.setJsonrpc("2.0");
              jsonResponse.setId(idPtr);
              auto error = jsonResponse.initError();
              // Map KJ exception types onto JSON-RPC error codes (-32000..-32099 are reserved
              // for implementation-defined server errors).
              switch (e.getType()) {
                case kj::Exception::Type::FAILED:
                  error.setCode(-32000);
                  break;
                case kj::Exception::Type::DISCONNECTED:
                  error.setCode(-32001);
                  break;
                case kj::Exception::Type::OVERLOADED:
                  error.setCode(-32002);
                  break;
                case kj::Exception::Type::UNIMPLEMENTED:
                  error.setCode(-32601);  // method not found
                  break;
              }
              error.setMessage(e.getDescription());
              return queueWrite(codec.encode(jsonResponse));
            });
            tasks.add(promise.attach(kj::mv(idCopy)));
          } else {
            // No 'id', so this is a notification.
            tasks.add(req.send().ignoreResult()
                .catch_([](kj::Exception&& exception) {
              if (exception.getType() != kj::Exception::Type::UNIMPLEMENTED) {
                KJ_LOG(ERROR, "JSON-RPC notification threw exception into the abyss", exception);
              }
            }));
          }
        } else {
          if (rpcMessage.hasId()) {
            queueError(rpcMessage.getId(), -32601, "Method not found");
          } else {
            // Ignore notification for unknown method.
          }
        }
        break;
      }

      case json::RpcMessage::RESULT: {
        // A successful response to one of our calls.
        auto id = rpcMessage.getId();
        kj::Maybe<uint> callId;
        if (id.isNumber()) callId = toCallId(id.getNumber());

        KJ_IF_MAYBE(cid, callId) {
          KJ_IF_MAYBE(awaited, awaitedResponses.find(*cid)) {
            KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
              codec.decode(rpcMessage.getResult(), awaited->context.getResults());
              awaited->fulfiller->fulfill();
            })) {
              // Errors always propagate from callee to caller, so we don't want to throw this
              // error back to the server.
              awaited->fulfiller->reject(kj::mv(*exception));
            }
          } else {
            // Probably, this is the response to a call that was canceled.
          }
        } else {
          // JSON-RPC doesn't define what to do if receiving a response with an invalid id.
          KJ_LOG(ERROR, "JSON-RPC response has invalid ID");
        }
        break;
      }

      case json::RpcMessage::ERROR: {
        // An error response, either to one of our calls or (with a null id) a protocol error.
        auto id = rpcMessage.getId();
        if (id.isNull()) {
          // Error message will be logged by KJ_CONTEXT, above.
          KJ_LOG(ERROR, "peer reports JSON-RPC protocol error");
        } else {
          kj::Maybe<uint> callId;
          if (id.isNumber()) callId = toCallId(id.getNumber());

          KJ_IF_MAYBE(cid, callId) {
            KJ_IF_MAYBE(awaited, awaitedResponses.find(*cid)) {
              auto error = rpcMessage.getError();
              auto code = error.getCode();
              kj::Exception::Type type =
                  code == -32601 ? kj::Exception::Type::UNIMPLEMENTED
                                 : kj::Exception::Type::FAILED;
              awaited->fulfiller->reject(kj::Exception(
                  type, __FILE__, __LINE__, kj::str(error.getMessage())));
            } else {
              // Probably, this is the response to a call that was canceled.
            }
          } else {
            // JSON-RPC doesn't define what to do if receiving a response with an invalid id.
            KJ_LOG(ERROR, "JSON-RPC response has invalid ID");
          }
        }
        break;
      }
    }

    return readLoop();
  });
}

// A failed background task (e.g. sending a response) is treated as fatal for the connection.
void JsonRpc::taskFailed(kj::Exception&& exception) {
  errorFulfiller->reject(kj::mv(exception));
}

// =======================================================================================

JsonRpc::ContentLengthTransport::ContentLengthTransport(kj::AsyncIoStream& stream)
    : stream(stream), input(kj::newHttpInputStream(stream, staticHeaderTable())) {}
JsonRpc::ContentLengthTransport::~ContentLengthTransport() noexcept(false) {}

// Writes "Content-Length: N\r\n\r\n" followed by the message text.
kj::Promise<void> JsonRpc::ContentLengthTransport::send(kj::StringPtr text) {
  auto headers = kj::str("Content-Length: ", text.size(), "\r\n\r\n");
  // NOTE(review): `parts` is not a local here (presumably a member array) so that it stays
  // valid for the duration of the asynchronous write -- confirm in json-rpc.h.
  parts[0] = headers.asBytes();
  parts[1] = text.asBytes();
  return stream.write(parts).attach(kj::mv(headers));
}

// Reads the next message through kj's HTTP input stream (header block, then body) and returns
// the body as text.
kj::Promise<kj::String> JsonRpc::ContentLengthTransport::receive() {
  return input->readMessage()
      .then([](kj::HttpInputStream::Message&& message) {
    auto promise = message.body->readAllText();
    // The body stream must outlive the read of its contents.
    return promise.attach(kj::mv(message.body));
  });
}

}  // namespace capnp