Commit 9b5fc00a authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #526 from capnproto/exception-callback-thread-initializer

Allow an ExceptionCallback to control how new threads' ExceptionCallbacks are initialized.
parents cc74158d f6d454b2
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "debug.h" #include "debug.h"
#include "threadlocal.h" #include "threadlocal.h"
#include "miniposix.h" #include "miniposix.h"
#include "function.h"
#include <stdlib.h> #include <stdlib.h>
#include <exception> #include <exception>
#include <new> #include <new>
...@@ -649,6 +650,10 @@ ExceptionCallback::StackTraceMode ExceptionCallback::stackTraceMode() { ...@@ -649,6 +650,10 @@ ExceptionCallback::StackTraceMode ExceptionCallback::stackTraceMode() {
return next.stackTraceMode(); return next.stackTraceMode();
} }
Function<void(Function<void()>)> ExceptionCallback::getThreadInitializer() {
return next.getThreadInitializer();
}
class ExceptionCallback::RootExceptionCallback: public ExceptionCallback { class ExceptionCallback::RootExceptionCallback: public ExceptionCallback {
public: public:
RootExceptionCallback(): ExceptionCallback(*this) {} RootExceptionCallback(): ExceptionCallback(*this) {}
...@@ -703,6 +708,14 @@ public: ...@@ -703,6 +708,14 @@ public:
#endif #endif
} }
Function<void(Function<void()>)> getThreadInitializer() override {
return [](Function<void()> func) {
// No initialization needed since RootExceptionCallback is automatically the root callback
// for new threads.
func();
};
}
private: private:
void logException(LogSeverity severity, Exception&& e) { void logException(LogSeverity severity, Exception&& e) {
// We intentionally go back to the top exception callback on the stack because we don't want to // We intentionally go back to the top exception callback on the stack because we don't want to
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
namespace kj { namespace kj {
class ExceptionImpl; class ExceptionImpl;
template <typename T> class Function;
class Exception { class Exception {
// Exception thrown in case of fatal errors. // Exception thrown in case of fatal errors.
...@@ -216,6 +217,11 @@ public: ...@@ -216,6 +217,11 @@ public:
virtual StackTraceMode stackTraceMode(); virtual StackTraceMode stackTraceMode();
// Returns the current preferred stack trace mode. // Returns the current preferred stack trace mode.
virtual Function<void(Function<void()>)> getThreadInitializer();
// Called just before a new thread is spawned using kj::Thread. Returns a function which should
// be invoked inside the new thread to initialize the thread's ExceptionCallback. The initializer
// function itself receives, as its parameter, the thread's main function, which it must call.
protected: protected:
ExceptionCallback& next; ExceptionCallback& next;
...@@ -224,6 +230,8 @@ private: ...@@ -224,6 +230,8 @@ private:
class RootExceptionCallback; class RootExceptionCallback;
friend ExceptionCallback& getExceptionCallback(); friend ExceptionCallback& getExceptionCallback();
friend class Thread;
}; };
ExceptionCallback& getExceptionCallback(); ExceptionCallback& getExceptionCallback();
......
...@@ -86,5 +86,39 @@ KJ_TEST("detaching thread doesn't delete function") { ...@@ -86,5 +86,39 @@ KJ_TEST("detaching thread doesn't delete function") {
} }
} }
class CapturingExceptionCallback final: public ExceptionCallback {
public:
CapturingExceptionCallback(String& target): target(target) {}
void logMessage(LogSeverity severity, const char* file, int line, int contextDepth,
String&& text) {
target = kj::mv(text);
}
private:
String& target;
};
class ThreadedExceptionCallback final: public ExceptionCallback {
public:
Function<void(Function<void()>)> getThreadInitializer() override {
return [this](Function<void()> func) {
CapturingExceptionCallback context(captured);
func();
};
}
String captured;
};
KJ_TEST("threads pick up exception callback initializer") {
ThreadedExceptionCallback context;
KJ_EXPECT(context.captured != "foobar");
Thread([]() {
KJ_LOG(ERROR, "foobar");
});
KJ_EXPECT(context.captured == "foobar", context.captured);
}
} // namespace } // namespace
} // namespace kj } // namespace kj
...@@ -34,7 +34,7 @@ namespace kj { ...@@ -34,7 +34,7 @@ namespace kj {
#if _WIN32 #if _WIN32
Thread::Thread(Function<void()> func): state(new ThreadState { kj::mv(func), nullptr, 2 }) { Thread::Thread(Function<void()> func): state(new ThreadState(kj::mv(func))) {
threadHandle = CreateThread(nullptr, 0, &runThread, state, 0, nullptr); threadHandle = CreateThread(nullptr, 0, &runThread, state, 0, nullptr);
if (threadHandle == nullptr) { if (threadHandle == nullptr) {
state->unref(); state->unref();
...@@ -61,20 +61,9 @@ void Thread::detach() { ...@@ -61,20 +61,9 @@ void Thread::detach() {
detached = true; detached = true;
} }
DWORD Thread::runThread(void* ptr) {
ThreadState* state = reinterpret_cast<ThreadState*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
state->func();
})) {
state->exception = kj::mv(*exception);
}
state->unref();
return 0;
}
#else // _WIN32 #else // _WIN32
Thread::Thread(Function<void()> func): state(new ThreadState { kj::mv(func), nullptr, 2 }) { Thread::Thread(Function<void()> func): state(new ThreadState(kj::mv(func))) {
static_assert(sizeof(threadId) >= sizeof(pthread_t), static_assert(sizeof(threadId) >= sizeof(pthread_t),
"pthread_t is larger than a long long on your platform. Please port."); "pthread_t is larger than a long long on your platform. Please port.");
...@@ -119,19 +108,14 @@ void Thread::detach() { ...@@ -119,19 +108,14 @@ void Thread::detach() {
state->unref(); state->unref();
} }
void* Thread::runThread(void* ptr) {
ThreadState* state = reinterpret_cast<ThreadState*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
state->func();
})) {
state->exception = kj::mv(*exception);
}
state->unref();
return nullptr;
}
#endif // _WIN32, else #endif // _WIN32, else
Thread::ThreadState::ThreadState(Function<void()> func)
: func(kj::mv(func)),
initializer(getExceptionCallback().getThreadInitializer()),
exception(nullptr),
refcount(2) {}
void Thread::ThreadState::unref() { void Thread::ThreadState::unref() {
#if _MSC_VER #if _MSC_VER
if (_InterlockedDecrement(&refcount) == 0) { if (_InterlockedDecrement(&refcount) == 0) {
...@@ -148,4 +132,19 @@ void Thread::ThreadState::unref() { ...@@ -148,4 +132,19 @@ void Thread::ThreadState::unref() {
} }
} }
#if _WIN32
DWORD Thread::runThread(void* ptr) {
#else
void* Thread::runThread(void* ptr) {
#endif
ThreadState* state = reinterpret_cast<ThreadState*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
state->initializer(kj::mv(state->func));
})) {
state->exception = kj::mv(*exception);
}
state->unref();
return 0;
}
} // namespace kj } // namespace kj
...@@ -53,7 +53,10 @@ public: ...@@ -53,7 +53,10 @@ public:
private: private:
struct ThreadState { struct ThreadState {
ThreadState(Function<void()> func);
Function<void()> func; Function<void()> func;
Function<void(Function<void()>)> initializer;
kj::Maybe<kj::Exception> exception; kj::Maybe<kj::Exception> exception;
unsigned int refcount; unsigned int refcount;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment