From 8341cc98c65fcc34a92aab60180de9febf332d6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Mon, 17 Nov 2025 16:32:23 +0100 Subject: [PATCH] refactor: rely on context.Context for log/slog and others (#1969) * refactor: rely on context.Context for log/slog and others * optimize * refactor * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix watcher-skip * better globals handling * fix --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- caddy/app.go | 12 ++- cgi.go | 4 +- context.go | 7 +- frankenphp.c | 2 +- frankenphp.go | 158 ++++++++++++++++++++++-------- frankenphp_test.go | 11 +++ internal/testserver/main.go | 10 +- internal/watcher/watch_pattern.go | 13 ++- internal/watcher/watcher-skip.go | 7 +- internal/watcher/watcher.go | 48 ++++++--- options.go | 11 +++ phpmainthread.go | 5 +- phpmainthread_test.go | 36 ++++--- phpthread.go | 19 ++-- scaling.go | 30 ++++-- threadinactive.go | 12 ++- threadregular.go | 54 ++++++---- threadtasks_test.go | 9 +- threadworker.go | 123 +++++++++++++++-------- types_test.go | 2 +- worker.go | 24 ++--- workerextension.go | 9 +- workerextension_test.go | 2 +- 23 files changed, 425 insertions(+), 183 deletions(-) diff --git a/caddy/app.go b/caddy/app.go index 02dcf170..52a88852 100644 --- a/caddy/app.go +++ b/caddy/app.go @@ -1,6 +1,7 @@ package caddy import ( + "context" "errors" "fmt" "log/slog" @@ -56,6 +57,7 @@ type FrankenPHPApp struct { MaxWaitTime time.Duration `json:"max_wait_time,omitempty"` metrics frankenphp.Metrics + ctx context.Context logger *slog.Logger } @@ -71,6 +73,7 @@ func (f FrankenPHPApp) CaddyModule() caddy.ModuleInfo { // Provision sets up the module. func (f *FrankenPHPApp) Provision(ctx caddy.Context) error { + f.ctx = ctx f.logger = ctx.Slogger() if httpApp, err := ctx.AppIfConfigured("http"); err == nil { @@ -128,9 +131,10 @@ func (f *FrankenPHPApp) Start() error { repl := caddy.NewReplacer() opts := []frankenphp.Option{ + frankenphp.WithContext(f.ctx), + frankenphp.WithLogger(f.logger), frankenphp.WithNumThreads(f.NumThreads), frankenphp.WithMaxThreads(f.MaxThreads), - frankenphp.WithLogger(f.logger), frankenphp.WithMetrics(f.metrics), frankenphp.WithPhpIni(f.PhpIni), frankenphp.WithMaxWaitTime(f.MaxWaitTime), @@ -159,7 +163,11 @@ func (f *FrankenPHPApp) Start() error { } func (f *FrankenPHPApp) Stop() error { - f.logger.Info("FrankenPHP stopped 🐘") + ctx := caddy.ActiveContext() + + if f.logger.Enabled(caddy.ActiveContext(), slog.LevelInfo) { + f.logger.LogAttrs(ctx, slog.LevelInfo, "FrankenPHP stopped 🐘") + } // attempt a graceful shutdown if caddy is exiting // note: Exiting() is currently marked as 'experimental' diff --git a/cgi.go b/cgi.go index 4c11a285..a04d4473 100644 --- a/cgi.go +++ b/cgi.go @@ -212,7 +212,7 @@ func addPreparedEnvToServer(fc *frankenPHPContext, trackVarsArray *C.zval) { //export go_register_variables func go_register_variables(threadIndex C.uintptr_t, trackVarsArray *C.zval) { thread := phpThreads[threadIndex] - fc := thread.getRequestContext() + fc := thread.frankenPHPContext() if fc.request != nil { addKnownVariablesToServer(thread, fc, trackVarsArray) @@ -279,7 +279,7 @@ func splitPos(path string, splitPath []string) int { //export go_update_request_info func go_update_request_info(threadIndex C.uintptr_t, info *C.sapi_request_info) C.bool { thread := phpThreads[threadIndex] - fc := thread.getRequestContext() + fc := thread.frankenPHPContext() request := fc.request if request == nil { diff --git a/context.go b/context.go index 543bb4a7..08e66bfe 100644 --- a/context.go +++ b/context.go @@ -38,6 +38,11 @@ type frankenPHPContext struct { startedAt time.Time } +type contextHolder struct { + ctx context.Context + frankenPHPContext *frankenPHPContext +} + // fromContext extracts the frankenPHPContext from a context. func fromContext(ctx context.Context) (fctx *frankenPHPContext, ok bool) { fctx, ok = ctx.Value(contextKey).(*frankenPHPContext) @@ -63,7 +68,7 @@ func NewRequestWithContext(r *http.Request, opts ...RequestOption) (*http.Reques } if fc.logger == nil { - fc.logger = logger + fc.logger = globalLogger } if fc.documentRoot == "" { diff --git a/frankenphp.c b/frankenphp.c index 3d124ece..db49efec 100644 --- a/frankenphp.c +++ b/frankenphp.c @@ -802,7 +802,7 @@ static void frankenphp_register_variables(zval *track_vars_array) { } static void frankenphp_log_message(const char *message, int syslog_type_int) { - go_log((char *)message, syslog_type_int); + go_log(thread_index, (char *)message, syslog_type_int); } static char *frankenphp_getenv(const char *name, size_t name_len) { diff --git a/frankenphp.go b/frankenphp.go index 953dd7b0..388aa301 100644 --- a/frankenphp.go +++ b/frankenphp.go @@ -60,8 +60,10 @@ var ( isRunning bool onServerShutdown []func() - loggerMu sync.RWMutex - logger *slog.Logger + // Set default values to make Shutdown() idempotent + globalMu sync.Mutex + globalCtx = context.Background() + globalLogger = slog.Default() metrics Metrics = nullMetrics{} @@ -231,15 +233,19 @@ func Init(options ...Option) error { } } - if opt.logger == nil { - // set a default logger - // to disable logging, set the logger to slog.New(slog.NewTextHandler(io.Discard, nil)) - opt.logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + globalMu.Lock() + + if opt.ctx != nil { + globalCtx = opt.ctx + opt.ctx = nil } - loggerMu.Lock() - logger = opt.logger - loggerMu.Unlock() + if opt.logger != nil { + globalLogger = opt.logger + opt.logger = nil + } + + globalMu.Unlock() if opt.metrics != nil { metrics = opt.metrics @@ -262,11 +268,16 @@ func Init(options ...Option) error { if config.ZTS { if !config.ZendMaxExecutionTimers && runtime.GOOS == "linux" { - logger.Warn(`Zend Max Execution Timers are not enabled, timeouts (e.g. "max_execution_time") are disabled, recompile PHP with the "--enable-zend-max-execution-timers" configuration option to fix this issue`) + if globalLogger.Enabled(globalCtx, slog.LevelWarn) { + globalLogger.LogAttrs(globalCtx, slog.LevelWarn, `Zend Max Execution Timers are not enabled, timeouts (e.g. "max_execution_time") are disabled, recompile PHP with the "--enable-zend-max-execution-timers" configuration option to fix this issue`) + } } } else { opt.numThreads = 1 - logger.Warn(`ZTS is not enabled, only 1 thread will be available, recompile PHP using the "--enable-zts" configuration option or performance will be degraded`) + + if globalLogger.Enabled(globalCtx, slog.LevelWarn) { + globalLogger.LogAttrs(globalCtx, slog.LevelWarn, `ZTS is not enabled, only 1 thread will be available, recompile PHP using the "--enable-zts" configuration option or performance will be degraded`) + } } mainThread, err := initPHPThreads(opt.numThreads, opt.maxThreads, opt.phpIni) @@ -274,7 +285,7 @@ func Init(options ...Option) error { return err } - regularRequestChan = make(chan *frankenPHPContext, opt.numThreads-workerThreadCount) + regularRequestChan = make(chan contextHolder, opt.numThreads-workerThreadCount) regularThreads = make([]*phpThread, 0, opt.numThreads-workerThreadCount) for i := 0; i < opt.numThreads-workerThreadCount; i++ { convertToRegularThread(getInactivePHPThread()) @@ -286,10 +297,12 @@ func Init(options ...Option) error { initAutoScaling(mainThread) - ctx := context.Background() - logger.LogAttrs(ctx, slog.LevelInfo, "FrankenPHP started 🐘", slog.String("php_version", Version().Version), slog.Int("num_threads", mainThread.numThreads), slog.Int("max_threads", mainThread.maxThreads)) - if EmbeddedAppPath != "" { - logger.LogAttrs(ctx, slog.LevelInfo, "embedded PHP app 📦", slog.String("path", EmbeddedAppPath)) + if globalLogger.Enabled(globalCtx, slog.LevelInfo) { + globalLogger.LogAttrs(globalCtx, slog.LevelInfo, "FrankenPHP started 🐘", slog.String("php_version", Version().Version), slog.Int("num_threads", mainThread.numThreads), slog.Int("max_threads", mainThread.maxThreads)) + + if EmbeddedAppPath != "" { + globalLogger.LogAttrs(globalCtx, slog.LevelInfo, "embedded PHP app 📦", slog.String("path", EmbeddedAppPath)) + } } // register the startup/shutdown hooks (mainly useful for extensions) @@ -329,7 +342,11 @@ func Shutdown() { } isRunning = false - logger.Debug("FrankenPHP shut down") + if globalLogger.Enabled(globalCtx, slog.LevelDebug) { + globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "FrankenPHP shut down") + } + + resetGlobals() } // ServeHTTP executes a PHP script according to the given context. @@ -343,7 +360,11 @@ func ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) error return ErrNotRunning } - fc, ok := fromContext(request.Context()) + ctx := request.Context() + fc, ok := fromContext(ctx) + + ch := contextHolder{ctx, fc} + if !ok { return ErrInvalidRequest } @@ -356,16 +377,17 @@ func ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) error // Detect if a worker is available to handle this request if fc.worker != nil { - return fc.worker.handleRequest(fc) + return fc.worker.handleRequest(ch) } // If no worker was available, send the request to non-worker threads - return handleRequestWithRegularPHPThreads(fc) + return handleRequestWithRegularPHPThreads(ch) } //export go_ub_write func go_ub_write(threadIndex C.uintptr_t, cBuf *C.char, length C.int) (C.size_t, C.bool) { - fc := phpThreads[threadIndex].getRequestContext() + thread := phpThreads[threadIndex] + fc := thread.frankenPHPContext() if fc.isDone { return 0, C.bool(true) @@ -380,14 +402,27 @@ func go_ub_write(threadIndex C.uintptr_t, cBuf *C.char, length C.int) (C.size_t, writer = fc.responseWriter } + var ctx context.Context + i, e := writer.Write(unsafe.Slice((*byte)(unsafe.Pointer(cBuf)), length)) if e != nil { - fc.logger.LogAttrs(context.Background(), slog.LevelWarn, "write error", slog.Any("error", e)) + ctx = thread.context() + + if fc.logger.Enabled(ctx, slog.LevelWarn) { + fc.logger.LogAttrs(ctx, slog.LevelWarn, "write error", slog.Any("error", e)) + } } if fc.responseWriter == nil { // probably starting a worker script, log the output - fc.logger.Info(writer.(*bytes.Buffer).String()) + + if ctx == nil { + ctx = thread.context() + } + + if fc.logger.Enabled(ctx, slog.LevelInfo) { + fc.logger.LogAttrs(ctx, slog.LevelInfo, writer.(*bytes.Buffer).String()) + } } return C.size_t(i), C.bool(fc.clientHasClosed()) @@ -396,12 +431,15 @@ func go_ub_write(threadIndex C.uintptr_t, cBuf *C.char, length C.int) (C.size_t, //export go_apache_request_headers func go_apache_request_headers(threadIndex C.uintptr_t) (*C.go_string, C.size_t) { thread := phpThreads[threadIndex] - fc := thread.getRequestContext() + ctx := thread.context() + fc := thread.frankenPHPContext() if fc.responseWriter == nil { // worker mode, not handling a request - logger.LogAttrs(context.Background(), slog.LevelDebug, "apache_request_headers() called in non-HTTP context", slog.String("worker", fc.worker.name)) + if globalLogger.Enabled(ctx, slog.LevelDebug) { + globalLogger.LogAttrs(ctx, slog.LevelDebug, "apache_request_headers() called in non-HTTP context", slog.String("worker", fc.worker.name)) + } return nil, 0 } @@ -429,10 +467,13 @@ func go_apache_request_headers(threadIndex C.uintptr_t) (*C.go_string, C.size_t) return sd, C.size_t(len(fc.request.Header)) } -func addHeader(fc *frankenPHPContext, cString *C.char, length C.int) { +func addHeader(ctx context.Context, fc *frankenPHPContext, cString *C.char, length C.int) { key, val := splitRawHeader(cString, int(length)) if key == "" { - fc.logger.LogAttrs(context.Background(), slog.LevelDebug, "invalid header", slog.String("header", C.GoStringN(cString, length))) + if fc.logger.Enabled(ctx, slog.LevelDebug) { + fc.logger.LogAttrs(ctx, slog.LevelDebug, "invalid header", slog.String("header", C.GoStringN(cString, length))) + } + return } fc.responseWriter.Header().Add(key, val) @@ -471,8 +512,8 @@ func splitRawHeader(rawHeader *C.char, length int) (string, string) { //export go_write_headers func go_write_headers(threadIndex C.uintptr_t, status C.int, headers *C.zend_llist) C.bool { - fc := phpThreads[threadIndex].getRequestContext() - + thread := phpThreads[threadIndex] + fc := thread.frankenPHPContext() if fc == nil { return C.bool(false) } @@ -490,7 +531,7 @@ func go_write_headers(threadIndex C.uintptr_t, status C.int, headers *C.zend_lli for current != nil { h := (*C.sapi_header_struct)(unsafe.Pointer(&(current.data))) - addHeader(fc, h.header, C.int(h.header_len)) + addHeader(thread.context(), fc, h.header, C.int(h.header_len)) current = current.next } @@ -499,13 +540,18 @@ func go_write_headers(threadIndex C.uintptr_t, status C.int, headers *C.zend_lli // go panics on invalid status code // https://github.com/golang/go/blob/9b8742f2e79438b9442afa4c0a0139d3937ea33f/src/net/http/server.go#L1162 if goStatus < 100 || goStatus > 999 { - logger.Warn(fmt.Sprintf("Invalid response status code %v", goStatus)) + ctx := thread.context() + + if globalLogger.Enabled(ctx, slog.LevelWarn) { + globalLogger.LogAttrs(ctx, slog.LevelWarn, "Invalid response status code", slog.Int("status_code", goStatus)) + } + goStatus = 500 } fc.responseWriter.WriteHeader(goStatus) - if goStatus >= 100 && goStatus < 200 { + if goStatus < 200 { // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses h := fc.responseWriter.Header() for k := range h { @@ -518,8 +564,13 @@ func go_write_headers(threadIndex C.uintptr_t, status C.int, headers *C.zend_lli //export go_sapi_flush func go_sapi_flush(threadIndex C.uintptr_t) bool { - fc := phpThreads[threadIndex].getRequestContext() - if fc == nil || fc.responseWriter == nil { + thread := phpThreads[threadIndex] + fc := thread.frankenPHPContext() + if fc == nil { + return false + } + + if fc.responseWriter == nil { return false } @@ -528,7 +579,11 @@ func go_sapi_flush(threadIndex C.uintptr_t) bool { } if err := http.NewResponseController(fc.responseWriter).Flush(); err != nil { - logger.LogAttrs(context.Background(), slog.LevelWarn, "the current responseWriter is not a flusher, if you are not using a custom build, please report this issue", slog.Any("error", err)) + ctx := thread.context() + + if globalLogger.Enabled(ctx, slog.LevelWarn) { + globalLogger.LogAttrs(ctx, slog.LevelWarn, "the current responseWriter is not a flusher, if you are not using a custom build, please report this issue", slog.Any("error", err)) + } } return false @@ -536,7 +591,7 @@ func go_sapi_flush(threadIndex C.uintptr_t) bool { //export go_read_post func go_read_post(threadIndex C.uintptr_t, cBuf *C.char, countBytes C.size_t) (readBytes C.size_t) { - fc := phpThreads[threadIndex].getRequestContext() + fc := phpThreads[threadIndex].frankenPHPContext() if fc.responseWriter == nil { return 0 @@ -555,7 +610,7 @@ func go_read_post(threadIndex C.uintptr_t, cBuf *C.char, countBytes C.size_t) (r //export go_read_cookies func go_read_cookies(threadIndex C.uintptr_t) *C.char { - request := phpThreads[threadIndex].getRequestContext().request + request := phpThreads[threadIndex].frankenPHPContext().request if request == nil { return nil } @@ -573,7 +628,8 @@ func go_read_cookies(threadIndex C.uintptr_t) *C.char { } //export go_log -func go_log(message *C.char, level C.int) { +func go_log(threadIndex C.uintptr_t, message *C.char, level C.int) { + ctx := phpThreads[threadIndex].context() m := C.GoString(message) var le syslogLevel @@ -585,21 +641,29 @@ func go_log(message *C.char, level C.int) { switch le { case syslogLevelEmerg, syslogLevelAlert, syslogLevelCrit, syslogLevelErr: - logger.LogAttrs(context.Background(), slog.LevelError, m, slog.String("syslog_level", syslogLevel(level).String())) + if globalLogger.Enabled(ctx, slog.LevelError) { + globalLogger.LogAttrs(ctx, slog.LevelError, m, slog.String("syslog_level", syslogLevel(level).String())) + } case syslogLevelWarn: - logger.LogAttrs(context.Background(), slog.LevelWarn, m, slog.String("syslog_level", syslogLevel(level).String())) + if globalLogger.Enabled(ctx, slog.LevelWarn) { + globalLogger.LogAttrs(ctx, slog.LevelWarn, m, slog.String("syslog_level", syslogLevel(level).String())) + } case syslogLevelDebug: - logger.LogAttrs(context.Background(), slog.LevelDebug, m, slog.String("syslog_level", syslogLevel(level).String())) + if globalLogger.Enabled(ctx, slog.LevelDebug) { + globalLogger.LogAttrs(ctx, slog.LevelDebug, m, slog.String("syslog_level", syslogLevel(level).String())) + } default: - logger.LogAttrs(context.Background(), slog.LevelInfo, m, slog.String("syslog_level", syslogLevel(level).String())) + if globalLogger.Enabled(ctx, slog.LevelInfo) { + globalLogger.LogAttrs(ctx, slog.LevelInfo, m, slog.String("syslog_level", syslogLevel(level).String())) + } } } //export go_is_context_done func go_is_context_done(threadIndex C.uintptr_t) C.bool { - return C.bool(phpThreads[threadIndex].getRequestContext().isDone) + return C.bool(phpThreads[threadIndex].frankenPHPContext().isDone) } // ExecuteScriptCLI executes the PHP script passed as parameter. @@ -648,3 +712,11 @@ func timeoutChan(timeout time.Duration) <-chan time.Time { return time.After(timeout) } + +func resetGlobals() { + globalMu.Lock() + globalCtx = context.Background() + globalLogger = slog.Default() + workers = nil + globalMu.Unlock() +} diff --git a/frankenphp_test.go b/frankenphp_test.go index 5450da6c..71379740 100644 --- a/frankenphp_test.go +++ b/frankenphp_test.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "errors" + "flag" "fmt" "io" "log" @@ -136,6 +137,16 @@ func testPost(url string, body string, handler func(http.ResponseWriter, *http.R return testRequest(req, handler, t) } +func TestMain(m *testing.M) { + flag.Parse() + + if !testing.Verbose() { + slog.SetDefault(slog.New(slog.DiscardHandler)) + } + + os.Exit(m.Run()) +} + func TestHelloWorld_module(t *testing.T) { testHelloWorld(t, nil) } func TestHelloWorld_worker(t *testing.T) { testHelloWorld(t, &testOptions{workerScript: "index.php"}) diff --git a/internal/testserver/main.go b/internal/testserver/main.go index 3e4af37b..249be647 100644 --- a/internal/testserver/main.go +++ b/internal/testserver/main.go @@ -10,8 +10,11 @@ import ( ) func main() { + ctx := context.Background() logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - if err := frankenphp.Init(frankenphp.WithLogger(logger)); err != nil { + + + if err := frankenphp.Init(frankenphp.WithContext(ctx), frankenphp.WithLogger(logger)); err != nil { panic(err) } defer frankenphp.Shutdown() @@ -32,6 +35,9 @@ func main() { port = "8080" } - logger.LogAttrs(context.Background(), slog.LevelError, "server error", slog.Any("error", http.ListenAndServe(":"+port, nil))) + if logger.Enabled(ctx, slog.LevelError) { + logger.LogAttrs(ctx, slog.LevelError, "server error", slog.Any("error", http.ListenAndServe(":"+port, nil))) + } + os.Exit(1) } diff --git a/internal/watcher/watch_pattern.go b/internal/watcher/watch_pattern.go index 5d9f2b63..37b2fdde 100644 --- a/internal/watcher/watch_pattern.go +++ b/internal/watcher/watch_pattern.go @@ -3,7 +3,6 @@ package watcher import ( - "context" "log/slog" "path/filepath" "strings" @@ -81,9 +80,10 @@ func isValidEventType(eventType int) bool { // 0:dir,1:file,2:hard_link,3:sym_link,4:watcher,5:other, func isValidPathType(pathType int, fileName string) bool { - if pathType == 4 { - logger.LogAttrs(context.Background(), slog.LevelDebug, "special edant/watcher event", slog.String("fileName", fileName)) + if pathType == 4 && logger.Enabled(ctx, slog.LevelDebug) { + logger.LogAttrs(ctx, slog.LevelDebug, "special edant/watcher event", slog.String("fileName", fileName)) } + return pathType <= 2 } @@ -163,9 +163,14 @@ func matchPattern(pattern string, fileName string) bool { if pattern == "" { return true } + patternMatches, err := filepath.Match(pattern, fileName) + if err != nil { - logger.LogAttrs(context.Background(), slog.LevelError, "failed to match filename", slog.String("file", fileName), slog.Any("error", err)) + if logger.Enabled(ctx, slog.LevelError) { + logger.LogAttrs(ctx, slog.LevelError, "failed to match filename", slog.String("file", fileName), slog.Any("error", err)) + } + return false } diff --git a/internal/watcher/watcher-skip.go b/internal/watcher/watcher-skip.go index 9dd24112..f801e3e1 100644 --- a/internal/watcher/watcher-skip.go +++ b/internal/watcher/watcher-skip.go @@ -2,9 +2,12 @@ package watcher -import "log/slog" +import ( + "context" + "log/slog" +) -func InitWatcher(filePatterns []string, callback func(), logger *slog.Logger) error { +func InitWatcher(ct context.Context, filePatterns []string, callback func(), logger *slog.Logger) error { logger.Error("watcher support is not enabled") return nil diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 5236cd3d..798b31b9 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -43,11 +43,13 @@ var ( activeWatcher *watcher // after stopping the watcher we will wait for eventual reloads to finish reloadWaitGroup sync.WaitGroup + // we are passing the context from the main package to the watcher + ctx context.Context // we are passing the logger from the main package to the watcher logger *slog.Logger ) -func InitWatcher(filePatterns []string, callback func(), slogger *slog.Logger) error { +func InitWatcher(ct context.Context, filePatterns []string, callback func(), slogger *slog.Logger) error { if len(filePatterns) == 0 { return nil } @@ -55,9 +57,10 @@ func InitWatcher(filePatterns []string, callback func(), slogger *slog.Logger) e return ErrAlreadyStarted } watcherIsActive.Store(true) + ctx = ct logger = slogger activeWatcher = &watcher{callback: callback} - err := activeWatcher.startWatching(filePatterns) + err := activeWatcher.startWatching(ctx, filePatterns) if err != nil { return err } @@ -71,7 +74,11 @@ func DrainWatcher() { return } watcherIsActive.Store(false) - logger.Debug("stopping watcher") + + if logger.Enabled(ctx, slog.LevelDebug) { + logger.LogAttrs(ctx, slog.LevelDebug, "stopping watcher") + } + activeWatcher.stopWatching() reloadWaitGroup.Wait() activeWatcher = nil @@ -79,15 +86,19 @@ func DrainWatcher() { // TODO: how to test this? func retryWatching(watchPattern *watchPattern) { - ctx := context.Background() - failureMu.Lock() defer failureMu.Unlock() if watchPattern.failureCount >= maxFailureCount { - logger.LogAttrs(ctx, slog.LevelWarn, "giving up watching", slog.String("dir", watchPattern.dir)) + if logger.Enabled(ctx, slog.LevelWarn) { + logger.LogAttrs(ctx, slog.LevelWarn, "giving up watching", slog.String("dir", watchPattern.dir)) + } + return } - logger.LogAttrs(ctx, slog.LevelInfo, "watcher was closed prematurely, retrying...", slog.String("dir", watchPattern.dir)) + + if logger.Enabled(ctx, slog.LevelInfo) { + logger.LogAttrs(ctx, slog.LevelInfo, "watcher was closed prematurely, retrying...", slog.String("dir", watchPattern.dir)) + } watchPattern.failureCount++ session, err := startSession(watchPattern) @@ -106,7 +117,7 @@ func retryWatching(watchPattern *watchPattern) { }() } -func (w *watcher) startWatching(filePatterns []string) error { +func (w *watcher) startWatching(ctx context.Context, filePatterns []string) error { w.trigger = make(chan string) w.stop = make(chan struct{}) w.sessions = make([]C.uintptr_t, len(filePatterns)) @@ -134,26 +145,29 @@ func (w *watcher) stopWatching() { } func startSession(w *watchPattern) (C.uintptr_t, error) { - ctx := context.Background() - handle := cgo.NewHandle(w) cDir := C.CString(w.dir) defer C.free(unsafe.Pointer(cDir)) watchSession := C.start_new_watcher(cDir, C.uintptr_t(handle)) if watchSession != 0 { - logger.LogAttrs(ctx, slog.LevelDebug, "watching", slog.String("dir", w.dir), slog.Any("patterns", w.patterns)) + if logger.Enabled(ctx, slog.LevelDebug) { + logger.LogAttrs(ctx, slog.LevelDebug, "watching", slog.String("dir", w.dir), slog.Any("patterns", w.patterns)) + } return watchSession, nil } - logger.LogAttrs(ctx, slog.LevelError, "couldn't start watching", slog.String("dir", w.dir)) + + if logger.Enabled(ctx, slog.LevelError) { + logger.LogAttrs(ctx, slog.LevelError, "couldn't start watching", slog.String("dir", w.dir)) + } return watchSession, ErrUnableToStartWatching } func stopSession(session C.uintptr_t) { success := C.stop_watcher(session) - if success == 0 { - logger.Warn("couldn't close the watcher") + if success == 0 && logger.Enabled(ctx, slog.LevelWarn) { + logger.LogAttrs(ctx, slog.LevelWarn, "couldn't close the watcher") } } @@ -195,7 +209,11 @@ func listenForFileEvents(triggerWatcher chan string, stopWatcher chan struct{}) timer.Reset(debounceDuration) case <-timer.C: timer.Stop() - logger.LogAttrs(context.Background(), slog.LevelInfo, "filesystem change detected", slog.String("file", lastChangedFile)) + + if logger.Enabled(ctx, slog.LevelInfo) { + logger.LogAttrs(ctx, slog.LevelInfo, "filesystem change detected", slog.String("file", lastChangedFile)) + } + scheduleReload() } } diff --git a/options.go b/options.go index 9d58125c..abf16f0f 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,7 @@ package frankenphp import ( + "context" "fmt" "log/slog" "time" @@ -19,6 +20,7 @@ type WorkerOption func(*workerOpt) error // // If you change this, also update the Caddy module and the documentation. type opt struct { + ctx context.Context numThreads int maxThreads int workers []workerOpt @@ -42,6 +44,15 @@ type workerOpt struct { onServerShutdown func() } +// WithContext sets the main context to use. +func WithContext(ctx context.Context) Option { + return func(h *opt) error { + h.ctx = ctx + + return nil + } +} + // WithNumThreads configures the number of PHP threads to start. func WithNumThreads(numThreads int) Option { return func(o *opt) error { diff --git a/phpmainthread.go b/phpmainthread.go index 3154bb77..40864206 100644 --- a/phpmainthread.go +++ b/phpmainthread.go @@ -8,7 +8,6 @@ package frankenphp // #include "frankenphp.h" import "C" import ( - "context" "log/slog" "strings" "sync" @@ -171,7 +170,9 @@ func (mainThread *phpMainThread) setAutomaticMaxThreads() { maxAllowedThreads := totalSysMemory / uint64(perThreadMemoryLimit) mainThread.maxThreads = int(maxAllowedThreads) - logger.LogAttrs(context.Background(), slog.LevelDebug, "Automatic thread limit", slog.Int("perThreadMemoryLimitMB", int(perThreadMemoryLimit/1024/1024)), slog.Int("maxThreads", mainThread.maxThreads)) + if globalLogger.Enabled(globalCtx, slog.LevelDebug) { + globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "Automatic thread limit", slog.Int("perThreadMemoryLimitMB", int(perThreadMemoryLimit/1024/1024)), slog.Int("maxThreads", mainThread.maxThreads)) + } } //export go_frankenphp_shutdown_main_thread diff --git a/phpmainthread_test.go b/phpmainthread_test.go index b2d37b35..edbd7b3a 100644 --- a/phpmainthread_test.go +++ b/phpmainthread_test.go @@ -18,8 +18,15 @@ import ( var testDataPath, _ = filepath.Abs("./testdata") +func setupGlobals(t *testing.T) { + t.Helper() + + t.Cleanup(Shutdown) + + resetGlobals() +} + func TestStartAndStopTheMainThreadWithOneInactiveThread(t *testing.T) { - logger = slog.New(slog.NewTextHandler(io.Discard, nil)) _, err := initPHPThreads(1, 1, nil) // boot 1 thread assert.NoError(t, err) @@ -28,12 +35,13 @@ func TestStartAndStopTheMainThreadWithOneInactiveThread(t *testing.T) { assert.True(t, phpThreads[0].state.is(stateInactive)) drainPHPThreads() + assert.Nil(t, phpThreads) } func TestTransitionRegularThreadToWorkerThread(t *testing.T) { - workers = nil - logger = slog.New(slog.NewTextHandler(io.Discard, nil)) + setupGlobals(t) + _, err := initPHPThreads(1, 1, nil) assert.NoError(t, err) @@ -42,7 +50,7 @@ func TestTransitionRegularThreadToWorkerThread(t *testing.T) { assert.IsType(t, ®ularThread{}, phpThreads[0].handler) // transition to worker thread - worker := getDummyWorker("transition-worker-1.php") + worker := getDummyWorker(t, "transition-worker-1.php") convertToWorkerThread(phpThreads[0], worker) assert.IsType(t, &workerThread{}, phpThreads[0].handler) assert.Len(t, worker.threads, 1) @@ -57,12 +65,12 @@ func TestTransitionRegularThreadToWorkerThread(t *testing.T) { } func TestTransitionAThreadBetween2DifferentWorkers(t *testing.T) { - workers = nil - logger = slog.New(slog.NewTextHandler(io.Discard, nil)) + setupGlobals(t) + _, err := initPHPThreads(1, 1, nil) assert.NoError(t, err) - firstWorker := getDummyWorker("transition-worker-1.php") - secondWorker := getDummyWorker("transition-worker-2.php") + firstWorker := getDummyWorker(t, "transition-worker-1.php") + secondWorker := getDummyWorker(t, "transition-worker-2.php") // convert to first worker thread convertToWorkerThread(phpThreads[0], firstWorker) @@ -151,13 +159,13 @@ func TestTransitionThreadsWhileDoingRequests(t *testing.T) { } func TestFinishBootingAWorkerScript(t *testing.T) { - workers = nil - logger = slog.New(slog.NewTextHandler(io.Discard, nil)) + setupGlobals(t) + _, err := initPHPThreads(1, 1, nil) assert.NoError(t, err) // boot the worker - worker := getDummyWorker("transition-worker-1.php") + worker := getDummyWorker(t, "transition-worker-1.php") convertToWorkerThread(phpThreads[0], worker) phpThreads[0].state.waitFor(stateReady) @@ -193,16 +201,20 @@ func TestReturnAnErrorIf2ModuleWorkersHaveTheSameName(t *testing.T) { assert.Error(t, err2, "two workers cannot have the same name") } -func getDummyWorker(fileName string) *worker { +func getDummyWorker(t *testing.T, fileName string) *worker { + t.Helper() + if workers == nil { workers = []*worker{} } + worker, _ := newWorker(workerOpt{ fileName: testDataPath + "/" + fileName, num: 1, maxConsecutiveFailures: defaultMaxConsecutiveFailures, }) workers = append(workers, worker) + return worker } diff --git a/phpthread.go b/phpthread.go index a60aa8f0..ec807be9 100644 --- a/phpthread.go +++ b/phpthread.go @@ -5,7 +5,6 @@ package frankenphp import "C" import ( "context" - "log/slog" "runtime" "sync" "unsafe" @@ -16,7 +15,7 @@ import ( type phpThread struct { runtime.Pinner threadIndex int - requestChan chan *frankenPHPContext + requestChan chan contextHolder drainChan chan struct{} handlerMu sync.Mutex handler threadHandler @@ -29,13 +28,14 @@ type threadHandler interface { name() string beforeScriptExecution() string afterScriptExecution(exitStatus int) - getRequestContext() *frankenPHPContext + context() context.Context + frankenPHPContext() *frankenPHPContext } func newPHPThread(threadIndex int) *phpThread { return &phpThread{ threadIndex: threadIndex, - requestChan: make(chan *frankenPHPContext), + requestChan: make(chan contextHolder), state: newThreadState(), } } @@ -44,7 +44,6 @@ func newPHPThread(threadIndex int) *phpThread { func (thread *phpThread) boot() { // thread must be in reserved state to boot if !thread.state.compareAndSwap(stateReserved, stateBooting) && !thread.state.compareAndSwap(stateBootRequested, stateBooting) { - logger.Error("thread is not in reserved state: " + thread.state.name()) panic("thread is not in reserved state: " + thread.state.name()) } @@ -56,7 +55,6 @@ func (thread *phpThread) boot() { // start the actual posix thread - TODO: try this with go threads instead if !C.frankenphp_new_php_thread(C.uintptr_t(thread.threadIndex)) { - logger.LogAttrs(context.Background(), slog.LevelError, "unable to create thread", slog.Int("thread", thread.threadIndex)) panic("unable to create thread") } @@ -100,12 +98,17 @@ func (thread *phpThread) setHandler(handler threadHandler) { func (thread *phpThread) transitionToNewHandler() string { thread.state.set(stateTransitionInProgress) thread.state.waitFor(stateTransitionComplete) + // execute beforeScriptExecution of the new handler return thread.handler.beforeScriptExecution() } -func (thread *phpThread) getRequestContext() *frankenPHPContext { - return thread.handler.getRequestContext() +func (thread *phpThread) frankenPHPContext() *frankenPHPContext { + return thread.handler.frankenPHPContext() +} + +func (thread *phpThread) context() context.Context { + return thread.handler.context() } func (thread *phpThread) name() string { diff --git a/scaling.go b/scaling.go index 57e6c598..3f541c27 100644 --- a/scaling.go +++ b/scaling.go @@ -4,7 +4,6 @@ package frankenphp //#include import "C" import ( - "context" "errors" "log/slog" "sync" @@ -54,7 +53,11 @@ func initAutoScaling(mainThread *phpMainThread) { func drainAutoScaling() { scalingMu.Lock() - logger.LogAttrs(context.Background(), slog.LevelDebug, "shutting down autoscaling", slog.Int("autoScaledThreads", len(autoScaledThreads))) + + if globalLogger.Enabled(globalCtx, slog.LevelDebug) { + globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "shutting down autoscaling", slog.Int("autoScaledThreads", len(autoScaledThreads))) + } + scalingMu.Unlock() } @@ -94,13 +97,18 @@ func scaleWorkerThread(worker *worker) { thread, err := addWorkerThread(worker) if err != nil { - logger.LogAttrs(context.Background(), slog.LevelWarn, "could not increase max_threads, consider raising this limit", slog.String("worker", worker.name), slog.Any("error", err)) + if globalLogger.Enabled(globalCtx, slog.LevelWarn) { + globalLogger.LogAttrs(globalCtx, slog.LevelWarn, "could not increase max_threads, consider raising this limit", slog.String("worker", worker.name), slog.Any("error", err)) + } + return } autoScaledThreads = append(autoScaledThreads, thread) - logger.LogAttrs(context.Background(), slog.LevelInfo, "upscaling worker thread", slog.String("worker", worker.name), slog.Int("thread", thread.threadIndex), slog.Int("num_threads", len(autoScaledThreads))) + if globalLogger.Enabled(globalCtx, slog.LevelInfo) { + globalLogger.LogAttrs(globalCtx, slog.LevelInfo, "upscaling worker thread", slog.String("worker", worker.name), slog.Int("thread", thread.threadIndex), slog.Int("num_threads", len(autoScaledThreads))) + } } // scaleRegularThread adds a regular PHP thread automatically @@ -119,13 +127,18 @@ func scaleRegularThread() { thread, err := addRegularThread() if err != nil { - logger.LogAttrs(context.Background(), slog.LevelWarn, "could not increase max_threads, consider raising this limit", slog.Any("error", err)) + if globalLogger.Enabled(globalCtx, slog.LevelWarn) { + globalLogger.LogAttrs(globalCtx, slog.LevelWarn, "could not increase max_threads, consider raising this limit", slog.Any("error", err)) + } + return } autoScaledThreads = append(autoScaledThreads, thread) - logger.LogAttrs(context.Background(), slog.LevelInfo, "upscaling regular thread", slog.Int("thread", thread.threadIndex), slog.Int("num_threads", len(autoScaledThreads))) + if globalLogger.Enabled(globalCtx, slog.LevelInfo) { + globalLogger.LogAttrs(globalCtx, slog.LevelInfo, "upscaling regular thread", slog.Int("thread", thread.threadIndex), slog.Int("num_threads", len(autoScaledThreads))) + } } func startUpscalingThreads(maxScaledThreads int, scale chan *frankenPHPContext, done chan struct{}) { @@ -204,7 +217,10 @@ func deactivateThreads() { convertToInactiveThread(thread) stoppedThreadCount++ autoScaledThreads = append(autoScaledThreads[:i], autoScaledThreads[i+1:]...) - logger.LogAttrs(context.Background(), slog.LevelInfo, "downscaling thread", slog.Int("thread", thread.threadIndex), slog.Int64("wait_time", waitTime), slog.Int("num_threads", len(autoScaledThreads))) + + if globalLogger.Enabled(globalCtx, slog.LevelInfo) { + globalLogger.LogAttrs(globalCtx, slog.LevelInfo, "downscaling thread", slog.Int("thread", thread.threadIndex), slog.Int64("wait_time", waitTime), slog.Int("num_threads", len(autoScaledThreads))) + } continue } diff --git a/threadinactive.go b/threadinactive.go index 912d339f..50172e57 100644 --- a/threadinactive.go +++ b/threadinactive.go @@ -1,5 +1,7 @@ package frankenphp +import "context" + // representation of a thread with no work assigned to it // implements the threadHandler interface // each inactive thread weighs around ~350KB @@ -18,6 +20,7 @@ func (handler *inactiveThread) beforeScriptExecution() string { switch thread.state.get() { case stateTransitionRequested: return thread.transitionToNewHandler() + case stateBooting, stateTransitionComplete: thread.state.set(stateInactive) @@ -25,11 +28,14 @@ func (handler *inactiveThread) beforeScriptExecution() string { thread.state.markAsWaiting(true) thread.state.waitFor(stateTransitionRequested, stateShuttingDown) thread.state.markAsWaiting(false) + return handler.beforeScriptExecution() + case stateShuttingDown: // signal to stop return "" } + panic("unexpected state: " + thread.state.name()) } @@ -37,7 +43,11 @@ func (handler *inactiveThread) afterScriptExecution(int) { panic("inactive threads should not execute scripts") } -func (handler *inactiveThread) getRequestContext() *frankenPHPContext { +func (handler *inactiveThread) frankenPHPContext() *frankenPHPContext { + return nil +} + +func (handler *inactiveThread) context() context.Context { return nil } diff --git a/threadregular.go b/threadregular.go index 64accc60..b01c060d 100644 --- a/threadregular.go +++ b/threadregular.go @@ -1,6 +1,7 @@ package frankenphp import ( + "context" "sync" ) @@ -8,15 +9,16 @@ import ( // executes PHP scripts in a web context // implements the threadHandler interface type regularThread struct { - state *threadState - thread *phpThread - requestContext *frankenPHPContext + contextHolder + + state *threadState + thread *phpThread } var ( regularThreads []*phpThread regularThreadMu = &sync.RWMutex{} - regularRequestChan chan *frankenPHPContext + regularRequestChan chan contextHolder ) func convertToRegularThread(thread *phpThread) { @@ -33,25 +35,33 @@ func (handler *regularThread) beforeScriptExecution() string { case stateTransitionRequested: detachRegularThread(handler.thread) return handler.thread.transitionToNewHandler() + case stateTransitionComplete: handler.state.set(stateReady) return handler.waitForRequest() + case stateReady: return handler.waitForRequest() + case stateShuttingDown: detachRegularThread(handler.thread) // signal to stop return "" } + panic("unexpected state: " + handler.state.name()) } -func (handler *regularThread) afterScriptExecution(int) { +func (handler *regularThread) afterScriptExecution(_ int) { handler.afterRequest() } -func (handler *regularThread) getRequestContext() *frankenPHPContext { - return handler.requestContext +func (handler *regularThread) frankenPHPContext() *frankenPHPContext { + return handler.contextHolder.frankenPHPContext +} + +func (handler *regularThread) context() context.Context { + return handler.ctx } func (handler *regularThread) name() string { @@ -64,32 +74,36 @@ func (handler *regularThread) waitForRequest() string { handler.state.markAsWaiting(true) - var fc *frankenPHPContext + var ch contextHolder + select { case <-handler.thread.drainChan: // go back to beforeScriptExecution return handler.beforeScriptExecution() - case fc = <-regularRequestChan: + case ch = <-regularRequestChan: } - handler.requestContext = fc + handler.ctx = ch.ctx + handler.contextHolder.frankenPHPContext = ch.frankenPHPContext handler.state.markAsWaiting(false) // set the scriptFilename that should be executed - return fc.scriptFilename + return handler.contextHolder.frankenPHPContext.scriptFilename } func (handler *regularThread) afterRequest() { - handler.requestContext.closeContext() - handler.requestContext = nil + handler.contextHolder.frankenPHPContext.closeContext() + handler.contextHolder.frankenPHPContext = nil + handler.ctx = nil } -func handleRequestWithRegularPHPThreads(fc *frankenPHPContext) error { +func handleRequestWithRegularPHPThreads(ch contextHolder) error { metrics.StartRequest() + select { - case regularRequestChan <- fc: + case regularRequestChan <- ch: // a thread was available to handle the request immediately - <-fc.done + <-ch.frankenPHPContext.done metrics.StopRequest() return nil @@ -101,19 +115,19 @@ func handleRequestWithRegularPHPThreads(fc *frankenPHPContext) error { metrics.QueuedRequest() for { select { - case regularRequestChan <- fc: + case regularRequestChan <- ch: metrics.DequeuedRequest() - <-fc.done + <-ch.frankenPHPContext.done metrics.StopRequest() return nil - case scaleChan <- fc: + case scaleChan <- ch.frankenPHPContext: // the request has triggered scaling, continue to wait for a thread case <-timeoutChan(maxWaitTime): // the request has timed out stalling metrics.DequeuedRequest() - fc.reject(ErrMaxWaitTimeExceeded) + ch.frankenPHPContext.reject(ErrMaxWaitTimeExceeded) return ErrMaxWaitTimeExceeded } diff --git a/threadtasks_test.go b/threadtasks_test.go index d81c5553..24668da5 100644 --- a/threadtasks_test.go +++ b/threadtasks_test.go @@ -1,6 +1,7 @@ package frankenphp import ( + "context" "sync" ) @@ -60,11 +61,15 @@ func (handler *taskThread) beforeScriptExecution() string { panic("unexpected state: " + thread.state.name()) } -func (handler *taskThread) afterScriptExecution(int) { +func (handler *taskThread) afterScriptExecution(_ int) { panic("task threads should not execute scripts") } -func (handler *taskThread) getRequestContext() *frankenPHPContext { +func (handler *taskThread) frankenPHPContext() *frankenPHPContext { + return nil +} + +func (handler *taskThread) context() context.Context { return nil } diff --git a/threadworker.go b/threadworker.go index 5a59f927..09edc6ea 100644 --- a/threadworker.go +++ b/threadworker.go @@ -4,7 +4,6 @@ package frankenphp import "C" import ( "context" - "fmt" "log/slog" "path/filepath" "time" @@ -15,13 +14,15 @@ import ( // executes the PHP worker script in a loop // implements the threadHandler interface type workerThread struct { - state *threadState - thread *phpThread - worker *worker - dummyContext *frankenPHPContext - workerContext *frankenPHPContext - backoff *exponentialBackoff - isBootingScript bool // true if the worker has not reached frankenphp_handle_request yet + state *threadState + thread *phpThread + worker *worker + dummyFrankenPHPContext *frankenPHPContext + dummyContext context.Context + workerFrankenPHPContext *frankenPHPContext + workerContext context.Context + backoff *exponentialBackoff + isBootingScript bool // true if the worker has not reached frankenphp_handle_request yet } func convertToWorkerThread(thread *phpThread, worker *worker) { @@ -58,16 +59,20 @@ func (handler *workerThread) beforeScriptExecution() string { if handler.worker.onThreadReady != nil { handler.worker.onThreadReady(handler.thread.threadIndex) } + setupWorkerScript(handler, handler.worker) + return handler.worker.fileName case stateShuttingDown: if handler.worker.onThreadShutdown != nil { handler.worker.onThreadShutdown(handler.thread.threadIndex) } handler.worker.detachThread(handler.thread) + // signal to stop return "" } + panic("unexpected state: " + handler.state.name()) } @@ -75,7 +80,14 @@ func (handler *workerThread) afterScriptExecution(exitStatus int) { tearDownWorkerScript(handler, exitStatus) } -func (handler *workerThread) getRequestContext() *frankenPHPContext { +func (handler *workerThread) frankenPHPContext() *frankenPHPContext { + if handler.workerFrankenPHPContext != nil { + return handler.workerFrankenPHPContext + } + + return handler.dummyFrankenPHPContext +} +func (handler *workerThread) context() context.Context { if handler.workerContext != nil { return handler.workerContext } @@ -105,23 +117,29 @@ func setupWorkerScript(handler *workerThread, worker *worker) { panic(err) } + ctx := context.WithValue(globalCtx, contextKey, fc) + fc.worker = worker - handler.dummyContext = fc + handler.dummyFrankenPHPContext = fc + handler.dummyContext = ctx handler.isBootingScript = true clearSandboxedEnv(handler.thread) - logger.LogAttrs(context.Background(), slog.LevelDebug, "starting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex)) + + if globalLogger.Enabled(ctx, slog.LevelDebug) { + globalLogger.LogAttrs(ctx, slog.LevelDebug, "starting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex)) + } } func tearDownWorkerScript(handler *workerThread, exitStatus int) { worker := handler.worker + handler.dummyFrankenPHPContext = nil handler.dummyContext = nil - ctx := context.Background() - // if the worker request is not nil, the script might have crashed // make sure to close the worker request context - if handler.workerContext != nil { - handler.workerContext.closeContext() + if handler.workerFrankenPHPContext != nil { + handler.workerFrankenPHPContext.closeContext() + handler.workerFrankenPHPContext = nil handler.workerContext = nil } @@ -129,7 +147,10 @@ func tearDownWorkerScript(handler *workerThread, exitStatus int) { if exitStatus == 0 && !handler.isBootingScript { metrics.StopWorker(worker.name, StopReasonRestart) handler.backoff.recordSuccess() - logger.LogAttrs(ctx, slog.LevelDebug, "restarting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("exit_status", exitStatus)) + + if globalLogger.Enabled(globalCtx, slog.LevelDebug) { + globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "restarting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("exit_status", exitStatus)) + } return } @@ -139,20 +160,26 @@ func tearDownWorkerScript(handler *workerThread, exitStatus int) { if !handler.isBootingScript { // fatal error (could be due to exit(1), timeouts, etc.) - logger.LogAttrs(ctx, slog.LevelDebug, "restarting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("exit_status", exitStatus)) + if globalLogger.Enabled(globalCtx, slog.LevelDebug) { + globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "restarting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("exit_status", exitStatus)) + } return } - logger.LogAttrs(ctx, slog.LevelError, "worker script has not reached frankenphp_handle_request()", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex)) + if globalLogger.Enabled(globalCtx, slog.LevelError) { + globalLogger.LogAttrs(globalCtx, slog.LevelError, "worker script has not reached frankenphp_handle_request()", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex)) + } // panic after exponential backoff if the worker has never reached frankenphp_handle_request if handler.backoff.recordFailure() { if !watcherIsEnabled && !handler.state.is(stateReady) { - logger.LogAttrs(ctx, slog.LevelError, "too many consecutive worker failures", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("failures", handler.backoff.failureCount)) panic("too many consecutive worker failures") } - logger.LogAttrs(ctx, slog.LevelWarn, "many consecutive worker failures", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("failures", handler.backoff.failureCount)) + + if globalLogger.Enabled(globalCtx, slog.LevelWarn) { + globalLogger.LogAttrs(globalCtx, slog.LevelWarn, "many consecutive worker failures", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("failures", handler.backoff.failureCount)) + } } } @@ -161,8 +188,9 @@ func (handler *workerThread) waitForWorkerRequest() (bool, any) { // unpin any memory left over from previous requests handler.thread.Unpin() - ctx := context.Background() - logger.LogAttrs(ctx, slog.LevelDebug, "waiting for request", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex)) + if globalLogger.Enabled(globalCtx, slog.LevelDebug) { + globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "waiting for request", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex)) + } // Clear the first dummy request created to initialize the worker if handler.isBootingScript { @@ -182,10 +210,12 @@ func (handler *workerThread) waitForWorkerRequest() (bool, any) { handler.state.markAsWaiting(true) - var fc *frankenPHPContext + var requestCH contextHolder select { case <-handler.thread.drainChan: - logger.LogAttrs(ctx, slog.LevelDebug, "shutting down", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex)) + if globalLogger.Enabled(globalCtx, slog.LevelDebug) { + globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "shutting down", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex)) + } // flush the opcache when restarting due to watcher or admin api // note: this is done right before frankenphp_handle_request() returns 'false' @@ -194,20 +224,23 @@ func (handler *workerThread) waitForWorkerRequest() (bool, any) { } return false, nil - case fc = <-handler.thread.requestChan: - case fc = <-handler.worker.requestChan: + case requestCH = <-handler.thread.requestChan: + case requestCH = <-handler.worker.requestChan: } - handler.workerContext = fc + handler.workerContext = requestCH.ctx + handler.workerFrankenPHPContext = requestCH.frankenPHPContext handler.state.markAsWaiting(false) - if fc.request == nil { - logger.LogAttrs(ctx, slog.LevelDebug, "request handling started", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex)) - } else { - logger.LogAttrs(ctx, slog.LevelDebug, "request handling started", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex), slog.String("url", fc.request.RequestURI)) + if globalLogger.Enabled(requestCH.ctx, slog.LevelDebug) { + if handler.workerFrankenPHPContext.request == nil { + globalLogger.LogAttrs(requestCH.ctx, slog.LevelDebug, "request handling started", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex)) + } else { + globalLogger.LogAttrs(requestCH.ctx, slog.LevelDebug, "request handling started", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex), slog.String("url", handler.workerFrankenPHPContext.request.RequestURI)) + } } - return true, fc.handlerParameters + return true, handler.workerFrankenPHPContext.handlerParameters } // go_frankenphp_worker_handle_request_start is called at the start of every php request served. @@ -240,23 +273,28 @@ func go_frankenphp_worker_handle_request_start(threadIndex C.uintptr_t) (C.bool, //export go_frankenphp_finish_worker_request func go_frankenphp_finish_worker_request(threadIndex C.uintptr_t, retval *C.zval) { thread := phpThreads[threadIndex] - fc := thread.getRequestContext() + ctx := thread.context() + fc := ctx.Value(contextKey).(*frankenPHPContext) + if retval != nil { r, err := GoValue[any](unsafe.Pointer(retval)) - if err != nil { - logger.Error(fmt.Sprintf("cannot convert return value: %s", err)) + if err != nil && globalLogger.Enabled(ctx, slog.LevelError) { + globalLogger.LogAttrs(ctx, slog.LevelError, "cannot convert return value", slog.Any("error", err), slog.Int("thread", thread.threadIndex)) } fc.handlerReturn = r } fc.closeContext() + thread.handler.(*workerThread).workerFrankenPHPContext = nil thread.handler.(*workerThread).workerContext = nil - if fc.request == nil { - fc.logger.LogAttrs(context.Background(), slog.LevelDebug, "request handling finished", slog.String("worker", fc.worker.name), slog.Int("thread", thread.threadIndex)) - } else { - fc.logger.LogAttrs(context.Background(), slog.LevelDebug, "request handling finished", slog.String("worker", fc.worker.name), slog.Int("thread", thread.threadIndex), slog.String("url", fc.request.RequestURI)) + if globalLogger.Enabled(ctx, slog.LevelDebug) { + if fc.request == nil { + fc.logger.LogAttrs(ctx, slog.LevelDebug, "request handling finished", slog.String("worker", fc.worker.name), slog.Int("thread", thread.threadIndex)) + } else { + fc.logger.LogAttrs(ctx, slog.LevelDebug, "request handling finished", slog.String("worker", fc.worker.name), slog.Int("thread", thread.threadIndex), slog.String("url", fc.request.RequestURI)) + } } } @@ -265,9 +303,12 @@ func go_frankenphp_finish_worker_request(threadIndex C.uintptr_t, retval *C.zval //export go_frankenphp_finish_php_request func go_frankenphp_finish_php_request(threadIndex C.uintptr_t) { thread := phpThreads[threadIndex] - fc := thread.getRequestContext() + fc := thread.frankenPHPContext() fc.closeContext() - fc.logger.LogAttrs(context.Background(), slog.LevelDebug, "request handling finished", slog.Int("thread", thread.threadIndex), slog.String("url", fc.request.RequestURI)) + ctx := thread.context() + if fc.logger.Enabled(ctx, slog.LevelDebug) { + fc.logger.LogAttrs(ctx, slog.LevelDebug, "request handling finished", slog.Int("thread", thread.threadIndex), slog.String("url", fc.request.RequestURI)) + } } diff --git a/types_test.go b/types_test.go index 122fe930..9c2893f6 100644 --- a/types_test.go +++ b/types_test.go @@ -13,7 +13,7 @@ import ( // this is necessary if tests make use of PHP's internal allocation func testOnDummyPHPThread(t *testing.T, test func()) { t.Helper() - logger = slog.New(slog.NewTextHandler(io.Discard, nil)) + globalLogger = slog.New(slog.NewTextHandler(io.Discard, nil)) _, err := initPHPThreads(1, 1, nil) // boot 1 thread assert.NoError(t, err) handler := convertToTaskThread(phpThreads[0]) diff --git a/worker.go b/worker.go index cd592f6e..0a1dbda1 100644 --- a/worker.go +++ b/worker.go @@ -19,7 +19,7 @@ type worker struct { fileName string num int env PreparedEnv - requestChan chan *frankenPHPContext + requestChan chan contextHolder threads []*phpThread threadMutex sync.RWMutex allowPathMatching bool @@ -66,7 +66,7 @@ func initWorkers(opt []workerOpt) error { } watcherIsEnabled = true - if err := watcher.InitWatcher(directoriesToWatch, RestartWorkers, logger); err != nil { + if err := watcher.InitWatcher(globalCtx, directoriesToWatch, RestartWorkers, globalLogger); err != nil { return err } @@ -128,7 +128,7 @@ func newWorker(o workerOpt) (*worker, error) { fileName: absFileName, num: o.num, env: o.env, - requestChan: make(chan *frankenPHPContext), + requestChan: make(chan contextHolder), threads: make([]*phpThread, 0, o.num), allowPathMatching: allowPathMatching, maxConsecutiveFailures: o.maxConsecutiveFailures, @@ -228,17 +228,17 @@ func (worker *worker) countThreads() int { return l } -func (worker *worker) handleRequest(fc *frankenPHPContext) error { +func (worker *worker) handleRequest(ch contextHolder) error { metrics.StartWorkerRequest(worker.name) // dispatch requests to all worker threads in order worker.threadMutex.RLock() for _, thread := range worker.threads { select { - case thread.requestChan <- fc: + case thread.requestChan <- ch: worker.threadMutex.RUnlock() - <-fc.done - metrics.StopWorkerRequest(worker.name, time.Since(fc.startedAt)) + <-ch.frankenPHPContext.done + metrics.StopWorkerRequest(worker.name, time.Since(ch.frankenPHPContext.startedAt)) return nil default: @@ -251,19 +251,19 @@ func (worker *worker) handleRequest(fc *frankenPHPContext) error { metrics.QueuedWorkerRequest(worker.name) for { select { - case worker.requestChan <- fc: + case worker.requestChan <- ch: metrics.DequeuedWorkerRequest(worker.name) - <-fc.done - metrics.StopWorkerRequest(worker.name, time.Since(fc.startedAt)) + <-ch.frankenPHPContext.done + metrics.StopWorkerRequest(worker.name, time.Since(ch.frankenPHPContext.startedAt)) return nil - case scaleChan <- fc: + case scaleChan <- ch.frankenPHPContext: // the request has triggered scaling, continue to wait for a thread case <-timeoutChan(maxWaitTime): // the request has timed out stalling metrics.DequeuedWorkerRequest(worker.name) - fc.reject(ErrMaxWaitTimeExceeded) + ch.frankenPHPContext.reject(ErrMaxWaitTimeExceeded) return ErrMaxWaitTimeExceeded } diff --git a/workerextension.go b/workerextension.go index 49334685..82b74631 100644 --- a/workerextension.go +++ b/workerextension.go @@ -1,6 +1,7 @@ package frankenphp import ( + "context" "net/http" ) @@ -10,7 +11,7 @@ type Workers interface { // The generated HTTP response will be written through the provided writer. SendRequest(rw http.ResponseWriter, r *http.Request) error // SendMessage calls the closure passed to frankenphp_handle_request(), passes message as a parameter, and returns the value produced by the closure. - SendMessage(message any, rw http.ResponseWriter) (any, error) + SendMessage(ctx context.Context, message any, rw http.ResponseWriter) (any, error) // NumThreads returns the number of available threads. NumThreads() int } @@ -43,14 +44,14 @@ func (w *extensionWorkers) NumThreads() int { } // EXPERIMENTAL: SendMessage sends a message to the worker and waits for a response. -func (w *extensionWorkers) SendMessage(message any, rw http.ResponseWriter) (any, error) { +func (w *extensionWorkers) SendMessage(ctx context.Context, message any, rw http.ResponseWriter) (any, error) { fc := newFrankenPHPContext() - fc.logger = logger + fc.logger = globalLogger fc.worker = w.internalWorker fc.responseWriter = rw fc.handlerParameters = message - err := w.internalWorker.handleRequest(fc) + err := w.internalWorker.handleRequest(contextHolder{context.WithValue(ctx, contextKey, fc), fc}) return fc.handlerReturn, err } diff --git a/workerextension_test.go b/workerextension_test.go index 1719cb03..b1c3dd0b 100644 --- a/workerextension_test.go +++ b/workerextension_test.go @@ -69,7 +69,7 @@ func TestWorkerExtensionSendMessage(t *testing.T) { require.NoError(t, err) t.Cleanup(Shutdown) - ret, err := externalWorker.SendMessage("Hello Workers", nil) + ret, err := externalWorker.SendMessage(t.Context(), "Hello Workers", nil) require.NoError(t, err) assert.Equal(t, "received message: Hello Workers", ret)