Commit 9b5fc00a authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #526 from capnproto/exception-callback-thread-initializer

Allow an ExceptionCallback to control how new threads' ExceptionCallbacks are initialized.
parents cc74158d f6d454b2
......@@ -24,6 +24,7 @@
#include "debug.h"
#include "threadlocal.h"
#include "miniposix.h"
#include "function.h"
#include <stdlib.h>
#include <exception>
#include <new>
......@@ -649,6 +650,10 @@ ExceptionCallback::StackTraceMode ExceptionCallback::stackTraceMode() {
return next.stackTraceMode();
}
Function<void(Function<void()>)> ExceptionCallback::getThreadInitializer() {
return next.getThreadInitializer();
}
class ExceptionCallback::RootExceptionCallback: public ExceptionCallback {
public:
RootExceptionCallback(): ExceptionCallback(*this) {}
......@@ -703,6 +708,14 @@ public:
#endif
}
Function<void(Function<void()>)> getThreadInitializer() override {
return [](Function<void()> func) {
// No initialization needed since RootExceptionCallback is automatically the root callback
// for new threads.
func();
};
}
private:
void logException(LogSeverity severity, Exception&& e) {
// We intentionally go back to the top exception callback on the stack because we don't want to
......
......@@ -33,6 +33,7 @@
namespace kj {
class ExceptionImpl;
template <typename T> class Function;
class Exception {
// Exception thrown in case of fatal errors.
......@@ -216,6 +217,11 @@ public:
virtual StackTraceMode stackTraceMode();
// Returns the current preferred stack trace mode.
virtual Function<void(Function<void()>)> getThreadInitializer();
// Called just before a new thread is spawned using kj::Thread. Returns a function which should
// be invoked inside the new thread to initialize the thread's ExceptionCallback. The initializer
// function itself receives, as its parameter, the thread's main function, which it must call.
protected:
ExceptionCallback& next;
......@@ -224,6 +230,8 @@ private:
class RootExceptionCallback;
friend ExceptionCallback& getExceptionCallback();
friend class Thread;
};
ExceptionCallback& getExceptionCallback();
......
......@@ -86,5 +86,39 @@ KJ_TEST("detaching thread doesn't delete function") {
}
}
class CapturingExceptionCallback final: public ExceptionCallback {
public:
CapturingExceptionCallback(String& target): target(target) {}
void logMessage(LogSeverity severity, const char* file, int line, int contextDepth,
String&& text) {
target = kj::mv(text);
}
private:
String& target;
};
class ThreadedExceptionCallback final: public ExceptionCallback {
public:
Function<void(Function<void()>)> getThreadInitializer() override {
return [this](Function<void()> func) {
CapturingExceptionCallback context(captured);
func();
};
}
String captured;
};
KJ_TEST("threads pick up exception callback initializer") {
ThreadedExceptionCallback context;
KJ_EXPECT(context.captured != "foobar");
Thread([]() {
KJ_LOG(ERROR, "foobar");
});
KJ_EXPECT(context.captured == "foobar", context.captured);
}
} // namespace
} // namespace kj
......@@ -34,7 +34,7 @@ namespace kj {
#if _WIN32
Thread::Thread(Function<void()> func): state(new ThreadState { kj::mv(func), nullptr, 2 }) {
Thread::Thread(Function<void()> func): state(new ThreadState(kj::mv(func))) {
threadHandle = CreateThread(nullptr, 0, &runThread, state, 0, nullptr);
if (threadHandle == nullptr) {
state->unref();
......@@ -61,20 +61,9 @@ void Thread::detach() {
detached = true;
}
DWORD Thread::runThread(void* ptr) {
ThreadState* state = reinterpret_cast<ThreadState*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
state->func();
})) {
state->exception = kj::mv(*exception);
}
state->unref();
return 0;
}
#else // _WIN32
Thread::Thread(Function<void()> func): state(new ThreadState { kj::mv(func), nullptr, 2 }) {
Thread::Thread(Function<void()> func): state(new ThreadState(kj::mv(func))) {
static_assert(sizeof(threadId) >= sizeof(pthread_t),
"pthread_t is larger than a long long on your platform. Please port.");
......@@ -119,19 +108,14 @@ void Thread::detach() {
state->unref();
}
void* Thread::runThread(void* ptr) {
ThreadState* state = reinterpret_cast<ThreadState*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
state->func();
})) {
state->exception = kj::mv(*exception);
}
state->unref();
return nullptr;
}
#endif // _WIN32, else
Thread::ThreadState::ThreadState(Function<void()> func)
: func(kj::mv(func)),
initializer(getExceptionCallback().getThreadInitializer()),
exception(nullptr),
refcount(2) {}
void Thread::ThreadState::unref() {
#if _MSC_VER
if (_InterlockedDecrement(&refcount) == 0) {
......@@ -148,4 +132,19 @@ void Thread::ThreadState::unref() {
}
}
#if _WIN32
DWORD Thread::runThread(void* ptr) {
#else
void* Thread::runThread(void* ptr) {
#endif
ThreadState* state = reinterpret_cast<ThreadState*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
state->initializer(kj::mv(state->func));
})) {
state->exception = kj::mv(*exception);
}
state->unref();
return 0;
}
} // namespace kj
......@@ -53,7 +53,10 @@ public:
private:
struct ThreadState {
ThreadState(Function<void()> func);
Function<void()> func;
Function<void(Function<void()>)> initializer;
kj::Maybe<kj::Exception> exception;
unsigned int refcount;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment