Commit f6d454b2 authored by Kenton Varda's avatar Kenton Varda

Allow an ExceptionCallback to control how new threads' ExceptionCallbacks are initialized.

parent 67b9ea88
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "debug.h" #include "debug.h"
#include "threadlocal.h" #include "threadlocal.h"
#include "miniposix.h" #include "miniposix.h"
#include "function.h"
#include <stdlib.h> #include <stdlib.h>
#include <exception> #include <exception>
#include <new> #include <new>
...@@ -649,6 +650,10 @@ ExceptionCallback::StackTraceMode ExceptionCallback::stackTraceMode() { ...@@ -649,6 +650,10 @@ ExceptionCallback::StackTraceMode ExceptionCallback::stackTraceMode() {
return next.stackTraceMode(); return next.stackTraceMode();
} }
Function<void(Function<void()>)> ExceptionCallback::getThreadInitializer() {
return next.getThreadInitializer();
}
class ExceptionCallback::RootExceptionCallback: public ExceptionCallback { class ExceptionCallback::RootExceptionCallback: public ExceptionCallback {
public: public:
RootExceptionCallback(): ExceptionCallback(*this) {} RootExceptionCallback(): ExceptionCallback(*this) {}
...@@ -703,6 +708,14 @@ public: ...@@ -703,6 +708,14 @@ public:
#endif #endif
} }
Function<void(Function<void()>)> getThreadInitializer() override {
return [](Function<void()> func) {
// No initialization needed since RootExceptionCallback is automatically the root callback
// for new threads.
func();
};
}
private: private:
void logException(LogSeverity severity, Exception&& e) { void logException(LogSeverity severity, Exception&& e) {
// We intentionally go back to the top exception callback on the stack because we don't want to // We intentionally go back to the top exception callback on the stack because we don't want to
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
namespace kj { namespace kj {
class ExceptionImpl; class ExceptionImpl;
template <typename T> class Function;
class Exception { class Exception {
// Exception thrown in case of fatal errors. // Exception thrown in case of fatal errors.
...@@ -216,6 +217,11 @@ public: ...@@ -216,6 +217,11 @@ public:
virtual StackTraceMode stackTraceMode(); virtual StackTraceMode stackTraceMode();
// Returns the current preferred stack trace mode. // Returns the current preferred stack trace mode.
virtual Function<void(Function<void()>)> getThreadInitializer();
// Called just before a new thread is spawned using kj::Thread. Returns a function which should
// be invoked inside the new thread to initialize the thread's ExceptionCallback. The initializer
// function itself receives, as its parameter, the thread's main function, which it must call.
protected: protected:
ExceptionCallback& next; ExceptionCallback& next;
...@@ -224,6 +230,8 @@ private: ...@@ -224,6 +230,8 @@ private:
class RootExceptionCallback; class RootExceptionCallback;
friend ExceptionCallback& getExceptionCallback(); friend ExceptionCallback& getExceptionCallback();
friend class Thread;
}; };
ExceptionCallback& getExceptionCallback(); ExceptionCallback& getExceptionCallback();
......
...@@ -86,5 +86,39 @@ KJ_TEST("detaching thread doesn't delete function") { ...@@ -86,5 +86,39 @@ KJ_TEST("detaching thread doesn't delete function") {
} }
} }
class CapturingExceptionCallback final: public ExceptionCallback {
public:
CapturingExceptionCallback(String& target): target(target) {}
void logMessage(LogSeverity severity, const char* file, int line, int contextDepth,
String&& text) {
target = kj::mv(text);
}
private:
String& target;
};
class ThreadedExceptionCallback final: public ExceptionCallback {
public:
Function<void(Function<void()>)> getThreadInitializer() override {
return [this](Function<void()> func) {
CapturingExceptionCallback context(captured);
func();
};
}
String captured;
};
KJ_TEST("threads pick up exception callback initializer") {
ThreadedExceptionCallback context;
KJ_EXPECT(context.captured != "foobar");
Thread([]() {
KJ_LOG(ERROR, "foobar");
});
KJ_EXPECT(context.captured == "foobar", context.captured);
}
} // namespace } // namespace
} // namespace kj } // namespace kj
...@@ -112,6 +112,7 @@ void Thread::detach() { ...@@ -112,6 +112,7 @@ void Thread::detach() {
Thread::ThreadState::ThreadState(Function<void()> func) Thread::ThreadState::ThreadState(Function<void()> func)
: func(kj::mv(func)), : func(kj::mv(func)),
initializer(getExceptionCallback().getThreadInitializer()),
exception(nullptr), exception(nullptr),
refcount(2) {} refcount(2) {}
...@@ -138,7 +139,7 @@ void* Thread::runThread(void* ptr) { ...@@ -138,7 +139,7 @@ void* Thread::runThread(void* ptr) {
#endif #endif
ThreadState* state = reinterpret_cast<ThreadState*>(ptr); ThreadState* state = reinterpret_cast<ThreadState*>(ptr);
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
state->func(); state->initializer(kj::mv(state->func));
})) { })) {
state->exception = kj::mv(*exception); state->exception = kj::mv(*exception);
} }
......
...@@ -56,6 +56,7 @@ private: ...@@ -56,6 +56,7 @@ private:
ThreadState(Function<void()> func); ThreadState(Function<void()> func);
Function<void()> func; Function<void()> func;
Function<void(Function<void()>)> initializer;
kj::Maybe<kj::Exception> exception; kj::Maybe<kj::Exception> exception;
unsigned int refcount; unsigned int refcount;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment