refactor: rely on context.Context for log/slog and others (#1969)

* refactor: rely on context.Context for log/slog and others

* optimize

* refactor

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix watcher-skip

* better globals handling

* fix

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Kévin Dunglas
2025-11-17 16:32:23 +01:00
committed by GitHub
parent 40cb42aace
commit 8341cc98c6
23 changed files with 425 additions and 183 deletions

View File

@@ -1,6 +1,7 @@
package caddy package caddy
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
@@ -56,6 +57,7 @@ type FrankenPHPApp struct {
MaxWaitTime time.Duration `json:"max_wait_time,omitempty"` MaxWaitTime time.Duration `json:"max_wait_time,omitempty"`
metrics frankenphp.Metrics metrics frankenphp.Metrics
ctx context.Context
logger *slog.Logger logger *slog.Logger
} }
@@ -71,6 +73,7 @@ func (f FrankenPHPApp) CaddyModule() caddy.ModuleInfo {
// Provision sets up the module. // Provision sets up the module.
func (f *FrankenPHPApp) Provision(ctx caddy.Context) error { func (f *FrankenPHPApp) Provision(ctx caddy.Context) error {
f.ctx = ctx
f.logger = ctx.Slogger() f.logger = ctx.Slogger()
if httpApp, err := ctx.AppIfConfigured("http"); err == nil { if httpApp, err := ctx.AppIfConfigured("http"); err == nil {
@@ -128,9 +131,10 @@ func (f *FrankenPHPApp) Start() error {
repl := caddy.NewReplacer() repl := caddy.NewReplacer()
opts := []frankenphp.Option{ opts := []frankenphp.Option{
frankenphp.WithContext(f.ctx),
frankenphp.WithLogger(f.logger),
frankenphp.WithNumThreads(f.NumThreads), frankenphp.WithNumThreads(f.NumThreads),
frankenphp.WithMaxThreads(f.MaxThreads), frankenphp.WithMaxThreads(f.MaxThreads),
frankenphp.WithLogger(f.logger),
frankenphp.WithMetrics(f.metrics), frankenphp.WithMetrics(f.metrics),
frankenphp.WithPhpIni(f.PhpIni), frankenphp.WithPhpIni(f.PhpIni),
frankenphp.WithMaxWaitTime(f.MaxWaitTime), frankenphp.WithMaxWaitTime(f.MaxWaitTime),
@@ -159,7 +163,11 @@ func (f *FrankenPHPApp) Start() error {
} }
func (f *FrankenPHPApp) Stop() error { func (f *FrankenPHPApp) Stop() error {
f.logger.Info("FrankenPHP stopped 🐘") ctx := caddy.ActiveContext()
if f.logger.Enabled(caddy.ActiveContext(), slog.LevelInfo) {
f.logger.LogAttrs(ctx, slog.LevelInfo, "FrankenPHP stopped 🐘")
}
// attempt a graceful shutdown if caddy is exiting // attempt a graceful shutdown if caddy is exiting
// note: Exiting() is currently marked as 'experimental' // note: Exiting() is currently marked as 'experimental'

4
cgi.go
View File

@@ -212,7 +212,7 @@ func addPreparedEnvToServer(fc *frankenPHPContext, trackVarsArray *C.zval) {
//export go_register_variables //export go_register_variables
func go_register_variables(threadIndex C.uintptr_t, trackVarsArray *C.zval) { func go_register_variables(threadIndex C.uintptr_t, trackVarsArray *C.zval) {
thread := phpThreads[threadIndex] thread := phpThreads[threadIndex]
fc := thread.getRequestContext() fc := thread.frankenPHPContext()
if fc.request != nil { if fc.request != nil {
addKnownVariablesToServer(thread, fc, trackVarsArray) addKnownVariablesToServer(thread, fc, trackVarsArray)
@@ -279,7 +279,7 @@ func splitPos(path string, splitPath []string) int {
//export go_update_request_info //export go_update_request_info
func go_update_request_info(threadIndex C.uintptr_t, info *C.sapi_request_info) C.bool { func go_update_request_info(threadIndex C.uintptr_t, info *C.sapi_request_info) C.bool {
thread := phpThreads[threadIndex] thread := phpThreads[threadIndex]
fc := thread.getRequestContext() fc := thread.frankenPHPContext()
request := fc.request request := fc.request
if request == nil { if request == nil {

View File

@@ -38,6 +38,11 @@ type frankenPHPContext struct {
startedAt time.Time startedAt time.Time
} }
type contextHolder struct {
ctx context.Context
frankenPHPContext *frankenPHPContext
}
// fromContext extracts the frankenPHPContext from a context. // fromContext extracts the frankenPHPContext from a context.
func fromContext(ctx context.Context) (fctx *frankenPHPContext, ok bool) { func fromContext(ctx context.Context) (fctx *frankenPHPContext, ok bool) {
fctx, ok = ctx.Value(contextKey).(*frankenPHPContext) fctx, ok = ctx.Value(contextKey).(*frankenPHPContext)
@@ -63,7 +68,7 @@ func NewRequestWithContext(r *http.Request, opts ...RequestOption) (*http.Reques
} }
if fc.logger == nil { if fc.logger == nil {
fc.logger = logger fc.logger = globalLogger
} }
if fc.documentRoot == "" { if fc.documentRoot == "" {

View File

@@ -802,7 +802,7 @@ static void frankenphp_register_variables(zval *track_vars_array) {
} }
static void frankenphp_log_message(const char *message, int syslog_type_int) { static void frankenphp_log_message(const char *message, int syslog_type_int) {
go_log((char *)message, syslog_type_int); go_log(thread_index, (char *)message, syslog_type_int);
} }
static char *frankenphp_getenv(const char *name, size_t name_len) { static char *frankenphp_getenv(const char *name, size_t name_len) {

View File

@@ -60,8 +60,10 @@ var (
isRunning bool isRunning bool
onServerShutdown []func() onServerShutdown []func()
loggerMu sync.RWMutex // Set default values to make Shutdown() idempotent
logger *slog.Logger globalMu sync.Mutex
globalCtx = context.Background()
globalLogger = slog.Default()
metrics Metrics = nullMetrics{} metrics Metrics = nullMetrics{}
@@ -231,15 +233,19 @@ func Init(options ...Option) error {
} }
} }
if opt.logger == nil { globalMu.Lock()
// set a default logger
// to disable logging, set the logger to slog.New(slog.NewTextHandler(io.Discard, nil)) if opt.ctx != nil {
opt.logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) globalCtx = opt.ctx
opt.ctx = nil
} }
loggerMu.Lock() if opt.logger != nil {
logger = opt.logger globalLogger = opt.logger
loggerMu.Unlock() opt.logger = nil
}
globalMu.Unlock()
if opt.metrics != nil { if opt.metrics != nil {
metrics = opt.metrics metrics = opt.metrics
@@ -262,11 +268,16 @@ func Init(options ...Option) error {
if config.ZTS { if config.ZTS {
if !config.ZendMaxExecutionTimers && runtime.GOOS == "linux" { if !config.ZendMaxExecutionTimers && runtime.GOOS == "linux" {
logger.Warn(`Zend Max Execution Timers are not enabled, timeouts (e.g. "max_execution_time") are disabled, recompile PHP with the "--enable-zend-max-execution-timers" configuration option to fix this issue`) if globalLogger.Enabled(globalCtx, slog.LevelWarn) {
globalLogger.LogAttrs(globalCtx, slog.LevelWarn, `Zend Max Execution Timers are not enabled, timeouts (e.g. "max_execution_time") are disabled, recompile PHP with the "--enable-zend-max-execution-timers" configuration option to fix this issue`)
}
} }
} else { } else {
opt.numThreads = 1 opt.numThreads = 1
logger.Warn(`ZTS is not enabled, only 1 thread will be available, recompile PHP using the "--enable-zts" configuration option or performance will be degraded`)
if globalLogger.Enabled(globalCtx, slog.LevelWarn) {
globalLogger.LogAttrs(globalCtx, slog.LevelWarn, `ZTS is not enabled, only 1 thread will be available, recompile PHP using the "--enable-zts" configuration option or performance will be degraded`)
}
} }
mainThread, err := initPHPThreads(opt.numThreads, opt.maxThreads, opt.phpIni) mainThread, err := initPHPThreads(opt.numThreads, opt.maxThreads, opt.phpIni)
@@ -274,7 +285,7 @@ func Init(options ...Option) error {
return err return err
} }
regularRequestChan = make(chan *frankenPHPContext, opt.numThreads-workerThreadCount) regularRequestChan = make(chan contextHolder, opt.numThreads-workerThreadCount)
regularThreads = make([]*phpThread, 0, opt.numThreads-workerThreadCount) regularThreads = make([]*phpThread, 0, opt.numThreads-workerThreadCount)
for i := 0; i < opt.numThreads-workerThreadCount; i++ { for i := 0; i < opt.numThreads-workerThreadCount; i++ {
convertToRegularThread(getInactivePHPThread()) convertToRegularThread(getInactivePHPThread())
@@ -286,10 +297,12 @@ func Init(options ...Option) error {
initAutoScaling(mainThread) initAutoScaling(mainThread)
ctx := context.Background() if globalLogger.Enabled(globalCtx, slog.LevelInfo) {
logger.LogAttrs(ctx, slog.LevelInfo, "FrankenPHP started 🐘", slog.String("php_version", Version().Version), slog.Int("num_threads", mainThread.numThreads), slog.Int("max_threads", mainThread.maxThreads)) globalLogger.LogAttrs(globalCtx, slog.LevelInfo, "FrankenPHP started 🐘", slog.String("php_version", Version().Version), slog.Int("num_threads", mainThread.numThreads), slog.Int("max_threads", mainThread.maxThreads))
if EmbeddedAppPath != "" {
logger.LogAttrs(ctx, slog.LevelInfo, "embedded PHP app 📦", slog.String("path", EmbeddedAppPath)) if EmbeddedAppPath != "" {
globalLogger.LogAttrs(globalCtx, slog.LevelInfo, "embedded PHP app 📦", slog.String("path", EmbeddedAppPath))
}
} }
// register the startup/shutdown hooks (mainly useful for extensions) // register the startup/shutdown hooks (mainly useful for extensions)
@@ -329,7 +342,11 @@ func Shutdown() {
} }
isRunning = false isRunning = false
logger.Debug("FrankenPHP shut down") if globalLogger.Enabled(globalCtx, slog.LevelDebug) {
globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "FrankenPHP shut down")
}
resetGlobals()
} }
// ServeHTTP executes a PHP script according to the given context. // ServeHTTP executes a PHP script according to the given context.
@@ -343,7 +360,11 @@ func ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) error
return ErrNotRunning return ErrNotRunning
} }
fc, ok := fromContext(request.Context()) ctx := request.Context()
fc, ok := fromContext(ctx)
ch := contextHolder{ctx, fc}
if !ok { if !ok {
return ErrInvalidRequest return ErrInvalidRequest
} }
@@ -356,16 +377,17 @@ func ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) error
// Detect if a worker is available to handle this request // Detect if a worker is available to handle this request
if fc.worker != nil { if fc.worker != nil {
return fc.worker.handleRequest(fc) return fc.worker.handleRequest(ch)
} }
// If no worker was available, send the request to non-worker threads // If no worker was available, send the request to non-worker threads
return handleRequestWithRegularPHPThreads(fc) return handleRequestWithRegularPHPThreads(ch)
} }
//export go_ub_write //export go_ub_write
func go_ub_write(threadIndex C.uintptr_t, cBuf *C.char, length C.int) (C.size_t, C.bool) { func go_ub_write(threadIndex C.uintptr_t, cBuf *C.char, length C.int) (C.size_t, C.bool) {
fc := phpThreads[threadIndex].getRequestContext() thread := phpThreads[threadIndex]
fc := thread.frankenPHPContext()
if fc.isDone { if fc.isDone {
return 0, C.bool(true) return 0, C.bool(true)
@@ -380,14 +402,27 @@ func go_ub_write(threadIndex C.uintptr_t, cBuf *C.char, length C.int) (C.size_t,
writer = fc.responseWriter writer = fc.responseWriter
} }
var ctx context.Context
i, e := writer.Write(unsafe.Slice((*byte)(unsafe.Pointer(cBuf)), length)) i, e := writer.Write(unsafe.Slice((*byte)(unsafe.Pointer(cBuf)), length))
if e != nil { if e != nil {
fc.logger.LogAttrs(context.Background(), slog.LevelWarn, "write error", slog.Any("error", e)) ctx = thread.context()
if fc.logger.Enabled(ctx, slog.LevelWarn) {
fc.logger.LogAttrs(ctx, slog.LevelWarn, "write error", slog.Any("error", e))
}
} }
if fc.responseWriter == nil { if fc.responseWriter == nil {
// probably starting a worker script, log the output // probably starting a worker script, log the output
fc.logger.Info(writer.(*bytes.Buffer).String())
if ctx == nil {
ctx = thread.context()
}
if fc.logger.Enabled(ctx, slog.LevelInfo) {
fc.logger.LogAttrs(ctx, slog.LevelInfo, writer.(*bytes.Buffer).String())
}
} }
return C.size_t(i), C.bool(fc.clientHasClosed()) return C.size_t(i), C.bool(fc.clientHasClosed())
@@ -396,12 +431,15 @@ func go_ub_write(threadIndex C.uintptr_t, cBuf *C.char, length C.int) (C.size_t,
//export go_apache_request_headers //export go_apache_request_headers
func go_apache_request_headers(threadIndex C.uintptr_t) (*C.go_string, C.size_t) { func go_apache_request_headers(threadIndex C.uintptr_t) (*C.go_string, C.size_t) {
thread := phpThreads[threadIndex] thread := phpThreads[threadIndex]
fc := thread.getRequestContext() ctx := thread.context()
fc := thread.frankenPHPContext()
if fc.responseWriter == nil { if fc.responseWriter == nil {
// worker mode, not handling a request // worker mode, not handling a request
logger.LogAttrs(context.Background(), slog.LevelDebug, "apache_request_headers() called in non-HTTP context", slog.String("worker", fc.worker.name)) if globalLogger.Enabled(ctx, slog.LevelDebug) {
globalLogger.LogAttrs(ctx, slog.LevelDebug, "apache_request_headers() called in non-HTTP context", slog.String("worker", fc.worker.name))
}
return nil, 0 return nil, 0
} }
@@ -429,10 +467,13 @@ func go_apache_request_headers(threadIndex C.uintptr_t) (*C.go_string, C.size_t)
return sd, C.size_t(len(fc.request.Header)) return sd, C.size_t(len(fc.request.Header))
} }
func addHeader(fc *frankenPHPContext, cString *C.char, length C.int) { func addHeader(ctx context.Context, fc *frankenPHPContext, cString *C.char, length C.int) {
key, val := splitRawHeader(cString, int(length)) key, val := splitRawHeader(cString, int(length))
if key == "" { if key == "" {
fc.logger.LogAttrs(context.Background(), slog.LevelDebug, "invalid header", slog.String("header", C.GoStringN(cString, length))) if fc.logger.Enabled(ctx, slog.LevelDebug) {
fc.logger.LogAttrs(ctx, slog.LevelDebug, "invalid header", slog.String("header", C.GoStringN(cString, length)))
}
return return
} }
fc.responseWriter.Header().Add(key, val) fc.responseWriter.Header().Add(key, val)
@@ -471,8 +512,8 @@ func splitRawHeader(rawHeader *C.char, length int) (string, string) {
//export go_write_headers //export go_write_headers
func go_write_headers(threadIndex C.uintptr_t, status C.int, headers *C.zend_llist) C.bool { func go_write_headers(threadIndex C.uintptr_t, status C.int, headers *C.zend_llist) C.bool {
fc := phpThreads[threadIndex].getRequestContext() thread := phpThreads[threadIndex]
fc := thread.frankenPHPContext()
if fc == nil { if fc == nil {
return C.bool(false) return C.bool(false)
} }
@@ -490,7 +531,7 @@ func go_write_headers(threadIndex C.uintptr_t, status C.int, headers *C.zend_lli
for current != nil { for current != nil {
h := (*C.sapi_header_struct)(unsafe.Pointer(&(current.data))) h := (*C.sapi_header_struct)(unsafe.Pointer(&(current.data)))
addHeader(fc, h.header, C.int(h.header_len)) addHeader(thread.context(), fc, h.header, C.int(h.header_len))
current = current.next current = current.next
} }
@@ -499,13 +540,18 @@ func go_write_headers(threadIndex C.uintptr_t, status C.int, headers *C.zend_lli
// go panics on invalid status code // go panics on invalid status code
// https://github.com/golang/go/blob/9b8742f2e79438b9442afa4c0a0139d3937ea33f/src/net/http/server.go#L1162 // https://github.com/golang/go/blob/9b8742f2e79438b9442afa4c0a0139d3937ea33f/src/net/http/server.go#L1162
if goStatus < 100 || goStatus > 999 { if goStatus < 100 || goStatus > 999 {
logger.Warn(fmt.Sprintf("Invalid response status code %v", goStatus)) ctx := thread.context()
if globalLogger.Enabled(ctx, slog.LevelWarn) {
globalLogger.LogAttrs(ctx, slog.LevelWarn, "Invalid response status code", slog.Int("status_code", goStatus))
}
goStatus = 500 goStatus = 500
} }
fc.responseWriter.WriteHeader(goStatus) fc.responseWriter.WriteHeader(goStatus)
if goStatus >= 100 && goStatus < 200 { if goStatus < 200 {
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
h := fc.responseWriter.Header() h := fc.responseWriter.Header()
for k := range h { for k := range h {
@@ -518,8 +564,13 @@ func go_write_headers(threadIndex C.uintptr_t, status C.int, headers *C.zend_lli
//export go_sapi_flush //export go_sapi_flush
func go_sapi_flush(threadIndex C.uintptr_t) bool { func go_sapi_flush(threadIndex C.uintptr_t) bool {
fc := phpThreads[threadIndex].getRequestContext() thread := phpThreads[threadIndex]
if fc == nil || fc.responseWriter == nil { fc := thread.frankenPHPContext()
if fc == nil {
return false
}
if fc.responseWriter == nil {
return false return false
} }
@@ -528,7 +579,11 @@ func go_sapi_flush(threadIndex C.uintptr_t) bool {
} }
if err := http.NewResponseController(fc.responseWriter).Flush(); err != nil { if err := http.NewResponseController(fc.responseWriter).Flush(); err != nil {
logger.LogAttrs(context.Background(), slog.LevelWarn, "the current responseWriter is not a flusher, if you are not using a custom build, please report this issue", slog.Any("error", err)) ctx := thread.context()
if globalLogger.Enabled(ctx, slog.LevelWarn) {
globalLogger.LogAttrs(ctx, slog.LevelWarn, "the current responseWriter is not a flusher, if you are not using a custom build, please report this issue", slog.Any("error", err))
}
} }
return false return false
@@ -536,7 +591,7 @@ func go_sapi_flush(threadIndex C.uintptr_t) bool {
//export go_read_post //export go_read_post
func go_read_post(threadIndex C.uintptr_t, cBuf *C.char, countBytes C.size_t) (readBytes C.size_t) { func go_read_post(threadIndex C.uintptr_t, cBuf *C.char, countBytes C.size_t) (readBytes C.size_t) {
fc := phpThreads[threadIndex].getRequestContext() fc := phpThreads[threadIndex].frankenPHPContext()
if fc.responseWriter == nil { if fc.responseWriter == nil {
return 0 return 0
@@ -555,7 +610,7 @@ func go_read_post(threadIndex C.uintptr_t, cBuf *C.char, countBytes C.size_t) (r
//export go_read_cookies //export go_read_cookies
func go_read_cookies(threadIndex C.uintptr_t) *C.char { func go_read_cookies(threadIndex C.uintptr_t) *C.char {
request := phpThreads[threadIndex].getRequestContext().request request := phpThreads[threadIndex].frankenPHPContext().request
if request == nil { if request == nil {
return nil return nil
} }
@@ -573,7 +628,8 @@ func go_read_cookies(threadIndex C.uintptr_t) *C.char {
} }
//export go_log //export go_log
func go_log(message *C.char, level C.int) { func go_log(threadIndex C.uintptr_t, message *C.char, level C.int) {
ctx := phpThreads[threadIndex].context()
m := C.GoString(message) m := C.GoString(message)
var le syslogLevel var le syslogLevel
@@ -585,21 +641,29 @@ func go_log(message *C.char, level C.int) {
switch le { switch le {
case syslogLevelEmerg, syslogLevelAlert, syslogLevelCrit, syslogLevelErr: case syslogLevelEmerg, syslogLevelAlert, syslogLevelCrit, syslogLevelErr:
logger.LogAttrs(context.Background(), slog.LevelError, m, slog.String("syslog_level", syslogLevel(level).String())) if globalLogger.Enabled(ctx, slog.LevelError) {
globalLogger.LogAttrs(ctx, slog.LevelError, m, slog.String("syslog_level", syslogLevel(level).String()))
}
case syslogLevelWarn: case syslogLevelWarn:
logger.LogAttrs(context.Background(), slog.LevelWarn, m, slog.String("syslog_level", syslogLevel(level).String())) if globalLogger.Enabled(ctx, slog.LevelWarn) {
globalLogger.LogAttrs(ctx, slog.LevelWarn, m, slog.String("syslog_level", syslogLevel(level).String()))
}
case syslogLevelDebug: case syslogLevelDebug:
logger.LogAttrs(context.Background(), slog.LevelDebug, m, slog.String("syslog_level", syslogLevel(level).String())) if globalLogger.Enabled(ctx, slog.LevelDebug) {
globalLogger.LogAttrs(ctx, slog.LevelDebug, m, slog.String("syslog_level", syslogLevel(level).String()))
}
default: default:
logger.LogAttrs(context.Background(), slog.LevelInfo, m, slog.String("syslog_level", syslogLevel(level).String())) if globalLogger.Enabled(ctx, slog.LevelInfo) {
globalLogger.LogAttrs(ctx, slog.LevelInfo, m, slog.String("syslog_level", syslogLevel(level).String()))
}
} }
} }
//export go_is_context_done //export go_is_context_done
func go_is_context_done(threadIndex C.uintptr_t) C.bool { func go_is_context_done(threadIndex C.uintptr_t) C.bool {
return C.bool(phpThreads[threadIndex].getRequestContext().isDone) return C.bool(phpThreads[threadIndex].frankenPHPContext().isDone)
} }
// ExecuteScriptCLI executes the PHP script passed as parameter. // ExecuteScriptCLI executes the PHP script passed as parameter.
@@ -648,3 +712,11 @@ func timeoutChan(timeout time.Duration) <-chan time.Time {
return time.After(timeout) return time.After(timeout)
} }
func resetGlobals() {
globalMu.Lock()
globalCtx = context.Background()
globalLogger = slog.Default()
workers = nil
globalMu.Unlock()
}

View File

@@ -8,6 +8,7 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"flag"
"fmt" "fmt"
"io" "io"
"log" "log"
@@ -136,6 +137,16 @@ func testPost(url string, body string, handler func(http.ResponseWriter, *http.R
return testRequest(req, handler, t) return testRequest(req, handler, t)
} }
func TestMain(m *testing.M) {
flag.Parse()
if !testing.Verbose() {
slog.SetDefault(slog.New(slog.DiscardHandler))
}
os.Exit(m.Run())
}
func TestHelloWorld_module(t *testing.T) { testHelloWorld(t, nil) } func TestHelloWorld_module(t *testing.T) { testHelloWorld(t, nil) }
func TestHelloWorld_worker(t *testing.T) { func TestHelloWorld_worker(t *testing.T) {
testHelloWorld(t, &testOptions{workerScript: "index.php"}) testHelloWorld(t, &testOptions{workerScript: "index.php"})

View File

@@ -10,8 +10,11 @@ import (
) )
func main() { func main() {
ctx := context.Background()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
if err := frankenphp.Init(frankenphp.WithLogger(logger)); err != nil {
if err := frankenphp.Init(frankenphp.WithContext(ctx), frankenphp.WithLogger(logger)); err != nil {
panic(err) panic(err)
} }
defer frankenphp.Shutdown() defer frankenphp.Shutdown()
@@ -32,6 +35,9 @@ func main() {
port = "8080" port = "8080"
} }
logger.LogAttrs(context.Background(), slog.LevelError, "server error", slog.Any("error", http.ListenAndServe(":"+port, nil))) if logger.Enabled(ctx, slog.LevelError) {
logger.LogAttrs(ctx, slog.LevelError, "server error", slog.Any("error", http.ListenAndServe(":"+port, nil)))
}
os.Exit(1) os.Exit(1)
} }

View File

@@ -3,7 +3,6 @@
package watcher package watcher
import ( import (
"context"
"log/slog" "log/slog"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -81,9 +80,10 @@ func isValidEventType(eventType int) bool {
// 0:dir,1:file,2:hard_link,3:sym_link,4:watcher,5:other, // 0:dir,1:file,2:hard_link,3:sym_link,4:watcher,5:other,
func isValidPathType(pathType int, fileName string) bool { func isValidPathType(pathType int, fileName string) bool {
if pathType == 4 { if pathType == 4 && logger.Enabled(ctx, slog.LevelDebug) {
logger.LogAttrs(context.Background(), slog.LevelDebug, "special edant/watcher event", slog.String("fileName", fileName)) logger.LogAttrs(ctx, slog.LevelDebug, "special edant/watcher event", slog.String("fileName", fileName))
} }
return pathType <= 2 return pathType <= 2
} }
@@ -163,9 +163,14 @@ func matchPattern(pattern string, fileName string) bool {
if pattern == "" { if pattern == "" {
return true return true
} }
patternMatches, err := filepath.Match(pattern, fileName) patternMatches, err := filepath.Match(pattern, fileName)
if err != nil { if err != nil {
logger.LogAttrs(context.Background(), slog.LevelError, "failed to match filename", slog.String("file", fileName), slog.Any("error", err)) if logger.Enabled(ctx, slog.LevelError) {
logger.LogAttrs(ctx, slog.LevelError, "failed to match filename", slog.String("file", fileName), slog.Any("error", err))
}
return false return false
} }

View File

@@ -2,9 +2,12 @@
package watcher package watcher
import "log/slog" import (
"context"
"log/slog"
)
func InitWatcher(filePatterns []string, callback func(), logger *slog.Logger) error { func InitWatcher(ct context.Context, filePatterns []string, callback func(), logger *slog.Logger) error {
logger.Error("watcher support is not enabled") logger.Error("watcher support is not enabled")
return nil return nil

View File

@@ -43,11 +43,13 @@ var (
activeWatcher *watcher activeWatcher *watcher
// after stopping the watcher we will wait for eventual reloads to finish // after stopping the watcher we will wait for eventual reloads to finish
reloadWaitGroup sync.WaitGroup reloadWaitGroup sync.WaitGroup
// we are passing the context from the main package to the watcher
ctx context.Context
// we are passing the logger from the main package to the watcher // we are passing the logger from the main package to the watcher
logger *slog.Logger logger *slog.Logger
) )
func InitWatcher(filePatterns []string, callback func(), slogger *slog.Logger) error { func InitWatcher(ct context.Context, filePatterns []string, callback func(), slogger *slog.Logger) error {
if len(filePatterns) == 0 { if len(filePatterns) == 0 {
return nil return nil
} }
@@ -55,9 +57,10 @@ func InitWatcher(filePatterns []string, callback func(), slogger *slog.Logger) e
return ErrAlreadyStarted return ErrAlreadyStarted
} }
watcherIsActive.Store(true) watcherIsActive.Store(true)
ctx = ct
logger = slogger logger = slogger
activeWatcher = &watcher{callback: callback} activeWatcher = &watcher{callback: callback}
err := activeWatcher.startWatching(filePatterns) err := activeWatcher.startWatching(ctx, filePatterns)
if err != nil { if err != nil {
return err return err
} }
@@ -71,7 +74,11 @@ func DrainWatcher() {
return return
} }
watcherIsActive.Store(false) watcherIsActive.Store(false)
logger.Debug("stopping watcher")
if logger.Enabled(ctx, slog.LevelDebug) {
logger.LogAttrs(ctx, slog.LevelDebug, "stopping watcher")
}
activeWatcher.stopWatching() activeWatcher.stopWatching()
reloadWaitGroup.Wait() reloadWaitGroup.Wait()
activeWatcher = nil activeWatcher = nil
@@ -79,15 +86,19 @@ func DrainWatcher() {
// TODO: how to test this? // TODO: how to test this?
func retryWatching(watchPattern *watchPattern) { func retryWatching(watchPattern *watchPattern) {
ctx := context.Background()
failureMu.Lock() failureMu.Lock()
defer failureMu.Unlock() defer failureMu.Unlock()
if watchPattern.failureCount >= maxFailureCount { if watchPattern.failureCount >= maxFailureCount {
logger.LogAttrs(ctx, slog.LevelWarn, "giving up watching", slog.String("dir", watchPattern.dir)) if logger.Enabled(ctx, slog.LevelWarn) {
logger.LogAttrs(ctx, slog.LevelWarn, "giving up watching", slog.String("dir", watchPattern.dir))
}
return return
} }
logger.LogAttrs(ctx, slog.LevelInfo, "watcher was closed prematurely, retrying...", slog.String("dir", watchPattern.dir))
if logger.Enabled(ctx, slog.LevelInfo) {
logger.LogAttrs(ctx, slog.LevelInfo, "watcher was closed prematurely, retrying...", slog.String("dir", watchPattern.dir))
}
watchPattern.failureCount++ watchPattern.failureCount++
session, err := startSession(watchPattern) session, err := startSession(watchPattern)
@@ -106,7 +117,7 @@ func retryWatching(watchPattern *watchPattern) {
}() }()
} }
func (w *watcher) startWatching(filePatterns []string) error { func (w *watcher) startWatching(ctx context.Context, filePatterns []string) error {
w.trigger = make(chan string) w.trigger = make(chan string)
w.stop = make(chan struct{}) w.stop = make(chan struct{})
w.sessions = make([]C.uintptr_t, len(filePatterns)) w.sessions = make([]C.uintptr_t, len(filePatterns))
@@ -134,26 +145,29 @@ func (w *watcher) stopWatching() {
} }
func startSession(w *watchPattern) (C.uintptr_t, error) { func startSession(w *watchPattern) (C.uintptr_t, error) {
ctx := context.Background()
handle := cgo.NewHandle(w) handle := cgo.NewHandle(w)
cDir := C.CString(w.dir) cDir := C.CString(w.dir)
defer C.free(unsafe.Pointer(cDir)) defer C.free(unsafe.Pointer(cDir))
watchSession := C.start_new_watcher(cDir, C.uintptr_t(handle)) watchSession := C.start_new_watcher(cDir, C.uintptr_t(handle))
if watchSession != 0 { if watchSession != 0 {
logger.LogAttrs(ctx, slog.LevelDebug, "watching", slog.String("dir", w.dir), slog.Any("patterns", w.patterns)) if logger.Enabled(ctx, slog.LevelDebug) {
logger.LogAttrs(ctx, slog.LevelDebug, "watching", slog.String("dir", w.dir), slog.Any("patterns", w.patterns))
}
return watchSession, nil return watchSession, nil
} }
logger.LogAttrs(ctx, slog.LevelError, "couldn't start watching", slog.String("dir", w.dir))
if logger.Enabled(ctx, slog.LevelError) {
logger.LogAttrs(ctx, slog.LevelError, "couldn't start watching", slog.String("dir", w.dir))
}
return watchSession, ErrUnableToStartWatching return watchSession, ErrUnableToStartWatching
} }
func stopSession(session C.uintptr_t) { func stopSession(session C.uintptr_t) {
success := C.stop_watcher(session) success := C.stop_watcher(session)
if success == 0 { if success == 0 && logger.Enabled(ctx, slog.LevelWarn) {
logger.Warn("couldn't close the watcher") logger.LogAttrs(ctx, slog.LevelWarn, "couldn't close the watcher")
} }
} }
@@ -195,7 +209,11 @@ func listenForFileEvents(triggerWatcher chan string, stopWatcher chan struct{})
timer.Reset(debounceDuration) timer.Reset(debounceDuration)
case <-timer.C: case <-timer.C:
timer.Stop() timer.Stop()
logger.LogAttrs(context.Background(), slog.LevelInfo, "filesystem change detected", slog.String("file", lastChangedFile))
if logger.Enabled(ctx, slog.LevelInfo) {
logger.LogAttrs(ctx, slog.LevelInfo, "filesystem change detected", slog.String("file", lastChangedFile))
}
scheduleReload() scheduleReload()
} }
} }

View File

@@ -1,6 +1,7 @@
package frankenphp package frankenphp
import ( import (
"context"
"fmt" "fmt"
"log/slog" "log/slog"
"time" "time"
@@ -19,6 +20,7 @@ type WorkerOption func(*workerOpt) error
// //
// If you change this, also update the Caddy module and the documentation. // If you change this, also update the Caddy module and the documentation.
type opt struct { type opt struct {
ctx context.Context
numThreads int numThreads int
maxThreads int maxThreads int
workers []workerOpt workers []workerOpt
@@ -42,6 +44,15 @@ type workerOpt struct {
onServerShutdown func() onServerShutdown func()
} }
// WithContext sets the main context to use.
func WithContext(ctx context.Context) Option {
return func(h *opt) error {
h.ctx = ctx
return nil
}
}
// WithNumThreads configures the number of PHP threads to start. // WithNumThreads configures the number of PHP threads to start.
func WithNumThreads(numThreads int) Option { func WithNumThreads(numThreads int) Option {
return func(o *opt) error { return func(o *opt) error {

View File

@@ -8,7 +8,6 @@ package frankenphp
// #include "frankenphp.h" // #include "frankenphp.h"
import "C" import "C"
import ( import (
"context"
"log/slog" "log/slog"
"strings" "strings"
"sync" "sync"
@@ -171,7 +170,9 @@ func (mainThread *phpMainThread) setAutomaticMaxThreads() {
maxAllowedThreads := totalSysMemory / uint64(perThreadMemoryLimit) maxAllowedThreads := totalSysMemory / uint64(perThreadMemoryLimit)
mainThread.maxThreads = int(maxAllowedThreads) mainThread.maxThreads = int(maxAllowedThreads)
logger.LogAttrs(context.Background(), slog.LevelDebug, "Automatic thread limit", slog.Int("perThreadMemoryLimitMB", int(perThreadMemoryLimit/1024/1024)), slog.Int("maxThreads", mainThread.maxThreads)) if globalLogger.Enabled(globalCtx, slog.LevelDebug) {
globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "Automatic thread limit", slog.Int("perThreadMemoryLimitMB", int(perThreadMemoryLimit/1024/1024)), slog.Int("maxThreads", mainThread.maxThreads))
}
} }
//export go_frankenphp_shutdown_main_thread //export go_frankenphp_shutdown_main_thread

View File

@@ -18,8 +18,15 @@ import (
var testDataPath, _ = filepath.Abs("./testdata") var testDataPath, _ = filepath.Abs("./testdata")
func setupGlobals(t *testing.T) {
t.Helper()
t.Cleanup(Shutdown)
resetGlobals()
}
func TestStartAndStopTheMainThreadWithOneInactiveThread(t *testing.T) { func TestStartAndStopTheMainThreadWithOneInactiveThread(t *testing.T) {
logger = slog.New(slog.NewTextHandler(io.Discard, nil))
_, err := initPHPThreads(1, 1, nil) // boot 1 thread _, err := initPHPThreads(1, 1, nil) // boot 1 thread
assert.NoError(t, err) assert.NoError(t, err)
@@ -28,12 +35,13 @@ func TestStartAndStopTheMainThreadWithOneInactiveThread(t *testing.T) {
assert.True(t, phpThreads[0].state.is(stateInactive)) assert.True(t, phpThreads[0].state.is(stateInactive))
drainPHPThreads() drainPHPThreads()
assert.Nil(t, phpThreads) assert.Nil(t, phpThreads)
} }
func TestTransitionRegularThreadToWorkerThread(t *testing.T) { func TestTransitionRegularThreadToWorkerThread(t *testing.T) {
workers = nil setupGlobals(t)
logger = slog.New(slog.NewTextHandler(io.Discard, nil))
_, err := initPHPThreads(1, 1, nil) _, err := initPHPThreads(1, 1, nil)
assert.NoError(t, err) assert.NoError(t, err)
@@ -42,7 +50,7 @@ func TestTransitionRegularThreadToWorkerThread(t *testing.T) {
assert.IsType(t, &regularThread{}, phpThreads[0].handler) assert.IsType(t, &regularThread{}, phpThreads[0].handler)
// transition to worker thread // transition to worker thread
worker := getDummyWorker("transition-worker-1.php") worker := getDummyWorker(t, "transition-worker-1.php")
convertToWorkerThread(phpThreads[0], worker) convertToWorkerThread(phpThreads[0], worker)
assert.IsType(t, &workerThread{}, phpThreads[0].handler) assert.IsType(t, &workerThread{}, phpThreads[0].handler)
assert.Len(t, worker.threads, 1) assert.Len(t, worker.threads, 1)
@@ -57,12 +65,12 @@ func TestTransitionRegularThreadToWorkerThread(t *testing.T) {
} }
func TestTransitionAThreadBetween2DifferentWorkers(t *testing.T) { func TestTransitionAThreadBetween2DifferentWorkers(t *testing.T) {
workers = nil setupGlobals(t)
logger = slog.New(slog.NewTextHandler(io.Discard, nil))
_, err := initPHPThreads(1, 1, nil) _, err := initPHPThreads(1, 1, nil)
assert.NoError(t, err) assert.NoError(t, err)
firstWorker := getDummyWorker("transition-worker-1.php") firstWorker := getDummyWorker(t, "transition-worker-1.php")
secondWorker := getDummyWorker("transition-worker-2.php") secondWorker := getDummyWorker(t, "transition-worker-2.php")
// convert to first worker thread // convert to first worker thread
convertToWorkerThread(phpThreads[0], firstWorker) convertToWorkerThread(phpThreads[0], firstWorker)
@@ -151,13 +159,13 @@ func TestTransitionThreadsWhileDoingRequests(t *testing.T) {
} }
func TestFinishBootingAWorkerScript(t *testing.T) { func TestFinishBootingAWorkerScript(t *testing.T) {
workers = nil setupGlobals(t)
logger = slog.New(slog.NewTextHandler(io.Discard, nil))
_, err := initPHPThreads(1, 1, nil) _, err := initPHPThreads(1, 1, nil)
assert.NoError(t, err) assert.NoError(t, err)
// boot the worker // boot the worker
worker := getDummyWorker("transition-worker-1.php") worker := getDummyWorker(t, "transition-worker-1.php")
convertToWorkerThread(phpThreads[0], worker) convertToWorkerThread(phpThreads[0], worker)
phpThreads[0].state.waitFor(stateReady) phpThreads[0].state.waitFor(stateReady)
@@ -193,16 +201,20 @@ func TestReturnAnErrorIf2ModuleWorkersHaveTheSameName(t *testing.T) {
assert.Error(t, err2, "two workers cannot have the same name") assert.Error(t, err2, "two workers cannot have the same name")
} }
func getDummyWorker(fileName string) *worker { func getDummyWorker(t *testing.T, fileName string) *worker {
t.Helper()
if workers == nil { if workers == nil {
workers = []*worker{} workers = []*worker{}
} }
worker, _ := newWorker(workerOpt{ worker, _ := newWorker(workerOpt{
fileName: testDataPath + "/" + fileName, fileName: testDataPath + "/" + fileName,
num: 1, num: 1,
maxConsecutiveFailures: defaultMaxConsecutiveFailures, maxConsecutiveFailures: defaultMaxConsecutiveFailures,
}) })
workers = append(workers, worker) workers = append(workers, worker)
return worker return worker
} }

View File

@@ -5,7 +5,6 @@ package frankenphp
import "C" import "C"
import ( import (
"context" "context"
"log/slog"
"runtime" "runtime"
"sync" "sync"
"unsafe" "unsafe"
@@ -16,7 +15,7 @@ import (
type phpThread struct { type phpThread struct {
runtime.Pinner runtime.Pinner
threadIndex int threadIndex int
requestChan chan *frankenPHPContext requestChan chan contextHolder
drainChan chan struct{} drainChan chan struct{}
handlerMu sync.Mutex handlerMu sync.Mutex
handler threadHandler handler threadHandler
@@ -29,13 +28,14 @@ type threadHandler interface {
name() string name() string
beforeScriptExecution() string beforeScriptExecution() string
afterScriptExecution(exitStatus int) afterScriptExecution(exitStatus int)
getRequestContext() *frankenPHPContext context() context.Context
frankenPHPContext() *frankenPHPContext
} }
func newPHPThread(threadIndex int) *phpThread { func newPHPThread(threadIndex int) *phpThread {
return &phpThread{ return &phpThread{
threadIndex: threadIndex, threadIndex: threadIndex,
requestChan: make(chan *frankenPHPContext), requestChan: make(chan contextHolder),
state: newThreadState(), state: newThreadState(),
} }
} }
@@ -44,7 +44,6 @@ func newPHPThread(threadIndex int) *phpThread {
func (thread *phpThread) boot() { func (thread *phpThread) boot() {
// thread must be in reserved state to boot // thread must be in reserved state to boot
if !thread.state.compareAndSwap(stateReserved, stateBooting) && !thread.state.compareAndSwap(stateBootRequested, stateBooting) { if !thread.state.compareAndSwap(stateReserved, stateBooting) && !thread.state.compareAndSwap(stateBootRequested, stateBooting) {
logger.Error("thread is not in reserved state: " + thread.state.name())
panic("thread is not in reserved state: " + thread.state.name()) panic("thread is not in reserved state: " + thread.state.name())
} }
@@ -56,7 +55,6 @@ func (thread *phpThread) boot() {
// start the actual posix thread - TODO: try this with go threads instead // start the actual posix thread - TODO: try this with go threads instead
if !C.frankenphp_new_php_thread(C.uintptr_t(thread.threadIndex)) { if !C.frankenphp_new_php_thread(C.uintptr_t(thread.threadIndex)) {
logger.LogAttrs(context.Background(), slog.LevelError, "unable to create thread", slog.Int("thread", thread.threadIndex))
panic("unable to create thread") panic("unable to create thread")
} }
@@ -100,12 +98,17 @@ func (thread *phpThread) setHandler(handler threadHandler) {
func (thread *phpThread) transitionToNewHandler() string { func (thread *phpThread) transitionToNewHandler() string {
thread.state.set(stateTransitionInProgress) thread.state.set(stateTransitionInProgress)
thread.state.waitFor(stateTransitionComplete) thread.state.waitFor(stateTransitionComplete)
// execute beforeScriptExecution of the new handler // execute beforeScriptExecution of the new handler
return thread.handler.beforeScriptExecution() return thread.handler.beforeScriptExecution()
} }
func (thread *phpThread) getRequestContext() *frankenPHPContext { func (thread *phpThread) frankenPHPContext() *frankenPHPContext {
return thread.handler.getRequestContext() return thread.handler.frankenPHPContext()
}
func (thread *phpThread) context() context.Context {
return thread.handler.context()
} }
func (thread *phpThread) name() string { func (thread *phpThread) name() string {

View File

@@ -4,7 +4,6 @@ package frankenphp
//#include <sys/resource.h> //#include <sys/resource.h>
import "C" import "C"
import ( import (
"context"
"errors" "errors"
"log/slog" "log/slog"
"sync" "sync"
@@ -54,7 +53,11 @@ func initAutoScaling(mainThread *phpMainThread) {
func drainAutoScaling() { func drainAutoScaling() {
scalingMu.Lock() scalingMu.Lock()
logger.LogAttrs(context.Background(), slog.LevelDebug, "shutting down autoscaling", slog.Int("autoScaledThreads", len(autoScaledThreads)))
if globalLogger.Enabled(globalCtx, slog.LevelDebug) {
globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "shutting down autoscaling", slog.Int("autoScaledThreads", len(autoScaledThreads)))
}
scalingMu.Unlock() scalingMu.Unlock()
} }
@@ -94,13 +97,18 @@ func scaleWorkerThread(worker *worker) {
thread, err := addWorkerThread(worker) thread, err := addWorkerThread(worker)
if err != nil { if err != nil {
logger.LogAttrs(context.Background(), slog.LevelWarn, "could not increase max_threads, consider raising this limit", slog.String("worker", worker.name), slog.Any("error", err)) if globalLogger.Enabled(globalCtx, slog.LevelWarn) {
globalLogger.LogAttrs(globalCtx, slog.LevelWarn, "could not increase max_threads, consider raising this limit", slog.String("worker", worker.name), slog.Any("error", err))
}
return return
} }
autoScaledThreads = append(autoScaledThreads, thread) autoScaledThreads = append(autoScaledThreads, thread)
logger.LogAttrs(context.Background(), slog.LevelInfo, "upscaling worker thread", slog.String("worker", worker.name), slog.Int("thread", thread.threadIndex), slog.Int("num_threads", len(autoScaledThreads))) if globalLogger.Enabled(globalCtx, slog.LevelInfo) {
globalLogger.LogAttrs(globalCtx, slog.LevelInfo, "upscaling worker thread", slog.String("worker", worker.name), slog.Int("thread", thread.threadIndex), slog.Int("num_threads", len(autoScaledThreads)))
}
} }
// scaleRegularThread adds a regular PHP thread automatically // scaleRegularThread adds a regular PHP thread automatically
@@ -119,13 +127,18 @@ func scaleRegularThread() {
thread, err := addRegularThread() thread, err := addRegularThread()
if err != nil { if err != nil {
logger.LogAttrs(context.Background(), slog.LevelWarn, "could not increase max_threads, consider raising this limit", slog.Any("error", err)) if globalLogger.Enabled(globalCtx, slog.LevelWarn) {
globalLogger.LogAttrs(globalCtx, slog.LevelWarn, "could not increase max_threads, consider raising this limit", slog.Any("error", err))
}
return return
} }
autoScaledThreads = append(autoScaledThreads, thread) autoScaledThreads = append(autoScaledThreads, thread)
logger.LogAttrs(context.Background(), slog.LevelInfo, "upscaling regular thread", slog.Int("thread", thread.threadIndex), slog.Int("num_threads", len(autoScaledThreads))) if globalLogger.Enabled(globalCtx, slog.LevelInfo) {
globalLogger.LogAttrs(globalCtx, slog.LevelInfo, "upscaling regular thread", slog.Int("thread", thread.threadIndex), slog.Int("num_threads", len(autoScaledThreads)))
}
} }
func startUpscalingThreads(maxScaledThreads int, scale chan *frankenPHPContext, done chan struct{}) { func startUpscalingThreads(maxScaledThreads int, scale chan *frankenPHPContext, done chan struct{}) {
@@ -204,7 +217,10 @@ func deactivateThreads() {
convertToInactiveThread(thread) convertToInactiveThread(thread)
stoppedThreadCount++ stoppedThreadCount++
autoScaledThreads = append(autoScaledThreads[:i], autoScaledThreads[i+1:]...) autoScaledThreads = append(autoScaledThreads[:i], autoScaledThreads[i+1:]...)
logger.LogAttrs(context.Background(), slog.LevelInfo, "downscaling thread", slog.Int("thread", thread.threadIndex), slog.Int64("wait_time", waitTime), slog.Int("num_threads", len(autoScaledThreads)))
if globalLogger.Enabled(globalCtx, slog.LevelInfo) {
globalLogger.LogAttrs(globalCtx, slog.LevelInfo, "downscaling thread", slog.Int("thread", thread.threadIndex), slog.Int64("wait_time", waitTime), slog.Int("num_threads", len(autoScaledThreads)))
}
continue continue
} }

View File

@@ -1,5 +1,7 @@
package frankenphp package frankenphp
import "context"
// representation of a thread with no work assigned to it // representation of a thread with no work assigned to it
// implements the threadHandler interface // implements the threadHandler interface
// each inactive thread weighs around ~350KB // each inactive thread weighs around ~350KB
@@ -18,6 +20,7 @@ func (handler *inactiveThread) beforeScriptExecution() string {
switch thread.state.get() { switch thread.state.get() {
case stateTransitionRequested: case stateTransitionRequested:
return thread.transitionToNewHandler() return thread.transitionToNewHandler()
case stateBooting, stateTransitionComplete: case stateBooting, stateTransitionComplete:
thread.state.set(stateInactive) thread.state.set(stateInactive)
@@ -25,11 +28,14 @@ func (handler *inactiveThread) beforeScriptExecution() string {
thread.state.markAsWaiting(true) thread.state.markAsWaiting(true)
thread.state.waitFor(stateTransitionRequested, stateShuttingDown) thread.state.waitFor(stateTransitionRequested, stateShuttingDown)
thread.state.markAsWaiting(false) thread.state.markAsWaiting(false)
return handler.beforeScriptExecution() return handler.beforeScriptExecution()
case stateShuttingDown: case stateShuttingDown:
// signal to stop // signal to stop
return "" return ""
} }
panic("unexpected state: " + thread.state.name()) panic("unexpected state: " + thread.state.name())
} }
@@ -37,7 +43,11 @@ func (handler *inactiveThread) afterScriptExecution(int) {
panic("inactive threads should not execute scripts") panic("inactive threads should not execute scripts")
} }
func (handler *inactiveThread) getRequestContext() *frankenPHPContext { func (handler *inactiveThread) frankenPHPContext() *frankenPHPContext {
return nil
}
func (handler *inactiveThread) context() context.Context {
return nil return nil
} }

View File

@@ -1,6 +1,7 @@
package frankenphp package frankenphp
import ( import (
"context"
"sync" "sync"
) )
@@ -8,15 +9,16 @@ import (
// executes PHP scripts in a web context // executes PHP scripts in a web context
// implements the threadHandler interface // implements the threadHandler interface
type regularThread struct { type regularThread struct {
state *threadState contextHolder
thread *phpThread
requestContext *frankenPHPContext state *threadState
thread *phpThread
} }
var ( var (
regularThreads []*phpThread regularThreads []*phpThread
regularThreadMu = &sync.RWMutex{} regularThreadMu = &sync.RWMutex{}
regularRequestChan chan *frankenPHPContext regularRequestChan chan contextHolder
) )
func convertToRegularThread(thread *phpThread) { func convertToRegularThread(thread *phpThread) {
@@ -33,25 +35,33 @@ func (handler *regularThread) beforeScriptExecution() string {
case stateTransitionRequested: case stateTransitionRequested:
detachRegularThread(handler.thread) detachRegularThread(handler.thread)
return handler.thread.transitionToNewHandler() return handler.thread.transitionToNewHandler()
case stateTransitionComplete: case stateTransitionComplete:
handler.state.set(stateReady) handler.state.set(stateReady)
return handler.waitForRequest() return handler.waitForRequest()
case stateReady: case stateReady:
return handler.waitForRequest() return handler.waitForRequest()
case stateShuttingDown: case stateShuttingDown:
detachRegularThread(handler.thread) detachRegularThread(handler.thread)
// signal to stop // signal to stop
return "" return ""
} }
panic("unexpected state: " + handler.state.name()) panic("unexpected state: " + handler.state.name())
} }
func (handler *regularThread) afterScriptExecution(int) { func (handler *regularThread) afterScriptExecution(_ int) {
handler.afterRequest() handler.afterRequest()
} }
func (handler *regularThread) getRequestContext() *frankenPHPContext { func (handler *regularThread) frankenPHPContext() *frankenPHPContext {
return handler.requestContext return handler.contextHolder.frankenPHPContext
}
func (handler *regularThread) context() context.Context {
return handler.ctx
} }
func (handler *regularThread) name() string { func (handler *regularThread) name() string {
@@ -64,32 +74,36 @@ func (handler *regularThread) waitForRequest() string {
handler.state.markAsWaiting(true) handler.state.markAsWaiting(true)
var fc *frankenPHPContext var ch contextHolder
select { select {
case <-handler.thread.drainChan: case <-handler.thread.drainChan:
// go back to beforeScriptExecution // go back to beforeScriptExecution
return handler.beforeScriptExecution() return handler.beforeScriptExecution()
case fc = <-regularRequestChan: case ch = <-regularRequestChan:
} }
handler.requestContext = fc handler.ctx = ch.ctx
handler.contextHolder.frankenPHPContext = ch.frankenPHPContext
handler.state.markAsWaiting(false) handler.state.markAsWaiting(false)
// set the scriptFilename that should be executed // set the scriptFilename that should be executed
return fc.scriptFilename return handler.contextHolder.frankenPHPContext.scriptFilename
} }
func (handler *regularThread) afterRequest() { func (handler *regularThread) afterRequest() {
handler.requestContext.closeContext() handler.contextHolder.frankenPHPContext.closeContext()
handler.requestContext = nil handler.contextHolder.frankenPHPContext = nil
handler.ctx = nil
} }
func handleRequestWithRegularPHPThreads(fc *frankenPHPContext) error { func handleRequestWithRegularPHPThreads(ch contextHolder) error {
metrics.StartRequest() metrics.StartRequest()
select { select {
case regularRequestChan <- fc: case regularRequestChan <- ch:
// a thread was available to handle the request immediately // a thread was available to handle the request immediately
<-fc.done <-ch.frankenPHPContext.done
metrics.StopRequest() metrics.StopRequest()
return nil return nil
@@ -101,19 +115,19 @@ func handleRequestWithRegularPHPThreads(fc *frankenPHPContext) error {
metrics.QueuedRequest() metrics.QueuedRequest()
for { for {
select { select {
case regularRequestChan <- fc: case regularRequestChan <- ch:
metrics.DequeuedRequest() metrics.DequeuedRequest()
<-fc.done <-ch.frankenPHPContext.done
metrics.StopRequest() metrics.StopRequest()
return nil return nil
case scaleChan <- fc: case scaleChan <- ch.frankenPHPContext:
// the request has triggered scaling, continue to wait for a thread // the request has triggered scaling, continue to wait for a thread
case <-timeoutChan(maxWaitTime): case <-timeoutChan(maxWaitTime):
// the request has timed out stalling // the request has timed out stalling
metrics.DequeuedRequest() metrics.DequeuedRequest()
fc.reject(ErrMaxWaitTimeExceeded) ch.frankenPHPContext.reject(ErrMaxWaitTimeExceeded)
return ErrMaxWaitTimeExceeded return ErrMaxWaitTimeExceeded
} }

View File

@@ -1,6 +1,7 @@
package frankenphp package frankenphp
import ( import (
"context"
"sync" "sync"
) )
@@ -60,11 +61,15 @@ func (handler *taskThread) beforeScriptExecution() string {
panic("unexpected state: " + thread.state.name()) panic("unexpected state: " + thread.state.name())
} }
func (handler *taskThread) afterScriptExecution(int) { func (handler *taskThread) afterScriptExecution(_ int) {
panic("task threads should not execute scripts") panic("task threads should not execute scripts")
} }
func (handler *taskThread) getRequestContext() *frankenPHPContext { func (handler *taskThread) frankenPHPContext() *frankenPHPContext {
return nil
}
func (handler *taskThread) context() context.Context {
return nil return nil
} }

View File

@@ -4,7 +4,6 @@ package frankenphp
import "C" import "C"
import ( import (
"context" "context"
"fmt"
"log/slog" "log/slog"
"path/filepath" "path/filepath"
"time" "time"
@@ -15,13 +14,15 @@ import (
// executes the PHP worker script in a loop // executes the PHP worker script in a loop
// implements the threadHandler interface // implements the threadHandler interface
type workerThread struct { type workerThread struct {
state *threadState state *threadState
thread *phpThread thread *phpThread
worker *worker worker *worker
dummyContext *frankenPHPContext dummyFrankenPHPContext *frankenPHPContext
workerContext *frankenPHPContext dummyContext context.Context
backoff *exponentialBackoff workerFrankenPHPContext *frankenPHPContext
isBootingScript bool // true if the worker has not reached frankenphp_handle_request yet workerContext context.Context
backoff *exponentialBackoff
isBootingScript bool // true if the worker has not reached frankenphp_handle_request yet
} }
func convertToWorkerThread(thread *phpThread, worker *worker) { func convertToWorkerThread(thread *phpThread, worker *worker) {
@@ -58,16 +59,20 @@ func (handler *workerThread) beforeScriptExecution() string {
if handler.worker.onThreadReady != nil { if handler.worker.onThreadReady != nil {
handler.worker.onThreadReady(handler.thread.threadIndex) handler.worker.onThreadReady(handler.thread.threadIndex)
} }
setupWorkerScript(handler, handler.worker) setupWorkerScript(handler, handler.worker)
return handler.worker.fileName return handler.worker.fileName
case stateShuttingDown: case stateShuttingDown:
if handler.worker.onThreadShutdown != nil { if handler.worker.onThreadShutdown != nil {
handler.worker.onThreadShutdown(handler.thread.threadIndex) handler.worker.onThreadShutdown(handler.thread.threadIndex)
} }
handler.worker.detachThread(handler.thread) handler.worker.detachThread(handler.thread)
// signal to stop // signal to stop
return "" return ""
} }
panic("unexpected state: " + handler.state.name()) panic("unexpected state: " + handler.state.name())
} }
@@ -75,7 +80,14 @@ func (handler *workerThread) afterScriptExecution(exitStatus int) {
tearDownWorkerScript(handler, exitStatus) tearDownWorkerScript(handler, exitStatus)
} }
func (handler *workerThread) getRequestContext() *frankenPHPContext { func (handler *workerThread) frankenPHPContext() *frankenPHPContext {
if handler.workerFrankenPHPContext != nil {
return handler.workerFrankenPHPContext
}
return handler.dummyFrankenPHPContext
}
func (handler *workerThread) context() context.Context {
if handler.workerContext != nil { if handler.workerContext != nil {
return handler.workerContext return handler.workerContext
} }
@@ -105,23 +117,29 @@ func setupWorkerScript(handler *workerThread, worker *worker) {
panic(err) panic(err)
} }
ctx := context.WithValue(globalCtx, contextKey, fc)
fc.worker = worker fc.worker = worker
handler.dummyContext = fc handler.dummyFrankenPHPContext = fc
handler.dummyContext = ctx
handler.isBootingScript = true handler.isBootingScript = true
clearSandboxedEnv(handler.thread) clearSandboxedEnv(handler.thread)
logger.LogAttrs(context.Background(), slog.LevelDebug, "starting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex))
if globalLogger.Enabled(ctx, slog.LevelDebug) {
globalLogger.LogAttrs(ctx, slog.LevelDebug, "starting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex))
}
} }
func tearDownWorkerScript(handler *workerThread, exitStatus int) { func tearDownWorkerScript(handler *workerThread, exitStatus int) {
worker := handler.worker worker := handler.worker
handler.dummyFrankenPHPContext = nil
handler.dummyContext = nil handler.dummyContext = nil
ctx := context.Background()
// if the worker request is not nil, the script might have crashed // if the worker request is not nil, the script might have crashed
// make sure to close the worker request context // make sure to close the worker request context
if handler.workerContext != nil { if handler.workerFrankenPHPContext != nil {
handler.workerContext.closeContext() handler.workerFrankenPHPContext.closeContext()
handler.workerFrankenPHPContext = nil
handler.workerContext = nil handler.workerContext = nil
} }
@@ -129,7 +147,10 @@ func tearDownWorkerScript(handler *workerThread, exitStatus int) {
if exitStatus == 0 && !handler.isBootingScript { if exitStatus == 0 && !handler.isBootingScript {
metrics.StopWorker(worker.name, StopReasonRestart) metrics.StopWorker(worker.name, StopReasonRestart)
handler.backoff.recordSuccess() handler.backoff.recordSuccess()
logger.LogAttrs(ctx, slog.LevelDebug, "restarting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("exit_status", exitStatus))
if globalLogger.Enabled(globalCtx, slog.LevelDebug) {
globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "restarting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("exit_status", exitStatus))
}
return return
} }
@@ -139,20 +160,26 @@ func tearDownWorkerScript(handler *workerThread, exitStatus int) {
if !handler.isBootingScript { if !handler.isBootingScript {
// fatal error (could be due to exit(1), timeouts, etc.) // fatal error (could be due to exit(1), timeouts, etc.)
logger.LogAttrs(ctx, slog.LevelDebug, "restarting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("exit_status", exitStatus)) if globalLogger.Enabled(globalCtx, slog.LevelDebug) {
globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "restarting", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("exit_status", exitStatus))
}
return return
} }
logger.LogAttrs(ctx, slog.LevelError, "worker script has not reached frankenphp_handle_request()", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex)) if globalLogger.Enabled(globalCtx, slog.LevelError) {
globalLogger.LogAttrs(globalCtx, slog.LevelError, "worker script has not reached frankenphp_handle_request()", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex))
}
// panic after exponential backoff if the worker has never reached frankenphp_handle_request // panic after exponential backoff if the worker has never reached frankenphp_handle_request
if handler.backoff.recordFailure() { if handler.backoff.recordFailure() {
if !watcherIsEnabled && !handler.state.is(stateReady) { if !watcherIsEnabled && !handler.state.is(stateReady) {
logger.LogAttrs(ctx, slog.LevelError, "too many consecutive worker failures", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("failures", handler.backoff.failureCount))
panic("too many consecutive worker failures") panic("too many consecutive worker failures")
} }
logger.LogAttrs(ctx, slog.LevelWarn, "many consecutive worker failures", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("failures", handler.backoff.failureCount))
if globalLogger.Enabled(globalCtx, slog.LevelWarn) {
globalLogger.LogAttrs(globalCtx, slog.LevelWarn, "many consecutive worker failures", slog.String("worker", worker.name), slog.Int("thread", handler.thread.threadIndex), slog.Int("failures", handler.backoff.failureCount))
}
} }
} }
@@ -161,8 +188,9 @@ func (handler *workerThread) waitForWorkerRequest() (bool, any) {
// unpin any memory left over from previous requests // unpin any memory left over from previous requests
handler.thread.Unpin() handler.thread.Unpin()
ctx := context.Background() if globalLogger.Enabled(globalCtx, slog.LevelDebug) {
logger.LogAttrs(ctx, slog.LevelDebug, "waiting for request", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex)) globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "waiting for request", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex))
}
// Clear the first dummy request created to initialize the worker // Clear the first dummy request created to initialize the worker
if handler.isBootingScript { if handler.isBootingScript {
@@ -182,10 +210,12 @@ func (handler *workerThread) waitForWorkerRequest() (bool, any) {
handler.state.markAsWaiting(true) handler.state.markAsWaiting(true)
var fc *frankenPHPContext var requestCH contextHolder
select { select {
case <-handler.thread.drainChan: case <-handler.thread.drainChan:
logger.LogAttrs(ctx, slog.LevelDebug, "shutting down", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex)) if globalLogger.Enabled(globalCtx, slog.LevelDebug) {
globalLogger.LogAttrs(globalCtx, slog.LevelDebug, "shutting down", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex))
}
// flush the opcache when restarting due to watcher or admin api // flush the opcache when restarting due to watcher or admin api
// note: this is done right before frankenphp_handle_request() returns 'false' // note: this is done right before frankenphp_handle_request() returns 'false'
@@ -194,20 +224,23 @@ func (handler *workerThread) waitForWorkerRequest() (bool, any) {
} }
return false, nil return false, nil
case fc = <-handler.thread.requestChan: case requestCH = <-handler.thread.requestChan:
case fc = <-handler.worker.requestChan: case requestCH = <-handler.worker.requestChan:
} }
handler.workerContext = fc handler.workerContext = requestCH.ctx
handler.workerFrankenPHPContext = requestCH.frankenPHPContext
handler.state.markAsWaiting(false) handler.state.markAsWaiting(false)
if fc.request == nil { if globalLogger.Enabled(requestCH.ctx, slog.LevelDebug) {
logger.LogAttrs(ctx, slog.LevelDebug, "request handling started", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex)) if handler.workerFrankenPHPContext.request == nil {
} else { globalLogger.LogAttrs(requestCH.ctx, slog.LevelDebug, "request handling started", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex))
logger.LogAttrs(ctx, slog.LevelDebug, "request handling started", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex), slog.String("url", fc.request.RequestURI)) } else {
globalLogger.LogAttrs(requestCH.ctx, slog.LevelDebug, "request handling started", slog.String("worker", handler.worker.name), slog.Int("thread", handler.thread.threadIndex), slog.String("url", handler.workerFrankenPHPContext.request.RequestURI))
}
} }
return true, fc.handlerParameters return true, handler.workerFrankenPHPContext.handlerParameters
} }
// go_frankenphp_worker_handle_request_start is called at the start of every php request served. // go_frankenphp_worker_handle_request_start is called at the start of every php request served.
@@ -240,23 +273,28 @@ func go_frankenphp_worker_handle_request_start(threadIndex C.uintptr_t) (C.bool,
//export go_frankenphp_finish_worker_request //export go_frankenphp_finish_worker_request
func go_frankenphp_finish_worker_request(threadIndex C.uintptr_t, retval *C.zval) { func go_frankenphp_finish_worker_request(threadIndex C.uintptr_t, retval *C.zval) {
thread := phpThreads[threadIndex] thread := phpThreads[threadIndex]
fc := thread.getRequestContext() ctx := thread.context()
fc := ctx.Value(contextKey).(*frankenPHPContext)
if retval != nil { if retval != nil {
r, err := GoValue[any](unsafe.Pointer(retval)) r, err := GoValue[any](unsafe.Pointer(retval))
if err != nil { if err != nil && globalLogger.Enabled(ctx, slog.LevelError) {
logger.Error(fmt.Sprintf("cannot convert return value: %s", err)) globalLogger.LogAttrs(ctx, slog.LevelError, "cannot convert return value", slog.Any("error", err), slog.Int("thread", thread.threadIndex))
} }
fc.handlerReturn = r fc.handlerReturn = r
} }
fc.closeContext() fc.closeContext()
thread.handler.(*workerThread).workerFrankenPHPContext = nil
thread.handler.(*workerThread).workerContext = nil thread.handler.(*workerThread).workerContext = nil
if fc.request == nil { if globalLogger.Enabled(ctx, slog.LevelDebug) {
fc.logger.LogAttrs(context.Background(), slog.LevelDebug, "request handling finished", slog.String("worker", fc.worker.name), slog.Int("thread", thread.threadIndex)) if fc.request == nil {
} else { fc.logger.LogAttrs(ctx, slog.LevelDebug, "request handling finished", slog.String("worker", fc.worker.name), slog.Int("thread", thread.threadIndex))
fc.logger.LogAttrs(context.Background(), slog.LevelDebug, "request handling finished", slog.String("worker", fc.worker.name), slog.Int("thread", thread.threadIndex), slog.String("url", fc.request.RequestURI)) } else {
fc.logger.LogAttrs(ctx, slog.LevelDebug, "request handling finished", slog.String("worker", fc.worker.name), slog.Int("thread", thread.threadIndex), slog.String("url", fc.request.RequestURI))
}
} }
} }
@@ -265,9 +303,12 @@ func go_frankenphp_finish_worker_request(threadIndex C.uintptr_t, retval *C.zval
//export go_frankenphp_finish_php_request //export go_frankenphp_finish_php_request
func go_frankenphp_finish_php_request(threadIndex C.uintptr_t) { func go_frankenphp_finish_php_request(threadIndex C.uintptr_t) {
thread := phpThreads[threadIndex] thread := phpThreads[threadIndex]
fc := thread.getRequestContext() fc := thread.frankenPHPContext()
fc.closeContext() fc.closeContext()
fc.logger.LogAttrs(context.Background(), slog.LevelDebug, "request handling finished", slog.Int("thread", thread.threadIndex), slog.String("url", fc.request.RequestURI)) ctx := thread.context()
if fc.logger.Enabled(ctx, slog.LevelDebug) {
fc.logger.LogAttrs(ctx, slog.LevelDebug, "request handling finished", slog.Int("thread", thread.threadIndex), slog.String("url", fc.request.RequestURI))
}
} }

View File

@@ -13,7 +13,7 @@ import (
// this is necessary if tests make use of PHP's internal allocation // this is necessary if tests make use of PHP's internal allocation
func testOnDummyPHPThread(t *testing.T, test func()) { func testOnDummyPHPThread(t *testing.T, test func()) {
t.Helper() t.Helper()
logger = slog.New(slog.NewTextHandler(io.Discard, nil)) globalLogger = slog.New(slog.NewTextHandler(io.Discard, nil))
_, err := initPHPThreads(1, 1, nil) // boot 1 thread _, err := initPHPThreads(1, 1, nil) // boot 1 thread
assert.NoError(t, err) assert.NoError(t, err)
handler := convertToTaskThread(phpThreads[0]) handler := convertToTaskThread(phpThreads[0])

View File

@@ -19,7 +19,7 @@ type worker struct {
fileName string fileName string
num int num int
env PreparedEnv env PreparedEnv
requestChan chan *frankenPHPContext requestChan chan contextHolder
threads []*phpThread threads []*phpThread
threadMutex sync.RWMutex threadMutex sync.RWMutex
allowPathMatching bool allowPathMatching bool
@@ -66,7 +66,7 @@ func initWorkers(opt []workerOpt) error {
} }
watcherIsEnabled = true watcherIsEnabled = true
if err := watcher.InitWatcher(directoriesToWatch, RestartWorkers, logger); err != nil { if err := watcher.InitWatcher(globalCtx, directoriesToWatch, RestartWorkers, globalLogger); err != nil {
return err return err
} }
@@ -128,7 +128,7 @@ func newWorker(o workerOpt) (*worker, error) {
fileName: absFileName, fileName: absFileName,
num: o.num, num: o.num,
env: o.env, env: o.env,
requestChan: make(chan *frankenPHPContext), requestChan: make(chan contextHolder),
threads: make([]*phpThread, 0, o.num), threads: make([]*phpThread, 0, o.num),
allowPathMatching: allowPathMatching, allowPathMatching: allowPathMatching,
maxConsecutiveFailures: o.maxConsecutiveFailures, maxConsecutiveFailures: o.maxConsecutiveFailures,
@@ -228,17 +228,17 @@ func (worker *worker) countThreads() int {
return l return l
} }
func (worker *worker) handleRequest(fc *frankenPHPContext) error { func (worker *worker) handleRequest(ch contextHolder) error {
metrics.StartWorkerRequest(worker.name) metrics.StartWorkerRequest(worker.name)
// dispatch requests to all worker threads in order // dispatch requests to all worker threads in order
worker.threadMutex.RLock() worker.threadMutex.RLock()
for _, thread := range worker.threads { for _, thread := range worker.threads {
select { select {
case thread.requestChan <- fc: case thread.requestChan <- ch:
worker.threadMutex.RUnlock() worker.threadMutex.RUnlock()
<-fc.done <-ch.frankenPHPContext.done
metrics.StopWorkerRequest(worker.name, time.Since(fc.startedAt)) metrics.StopWorkerRequest(worker.name, time.Since(ch.frankenPHPContext.startedAt))
return nil return nil
default: default:
@@ -251,19 +251,19 @@ func (worker *worker) handleRequest(fc *frankenPHPContext) error {
metrics.QueuedWorkerRequest(worker.name) metrics.QueuedWorkerRequest(worker.name)
for { for {
select { select {
case worker.requestChan <- fc: case worker.requestChan <- ch:
metrics.DequeuedWorkerRequest(worker.name) metrics.DequeuedWorkerRequest(worker.name)
<-fc.done <-ch.frankenPHPContext.done
metrics.StopWorkerRequest(worker.name, time.Since(fc.startedAt)) metrics.StopWorkerRequest(worker.name, time.Since(ch.frankenPHPContext.startedAt))
return nil return nil
case scaleChan <- fc: case scaleChan <- ch.frankenPHPContext:
// the request has triggered scaling, continue to wait for a thread // the request has triggered scaling, continue to wait for a thread
case <-timeoutChan(maxWaitTime): case <-timeoutChan(maxWaitTime):
// the request has timed out stalling // the request has timed out stalling
metrics.DequeuedWorkerRequest(worker.name) metrics.DequeuedWorkerRequest(worker.name)
fc.reject(ErrMaxWaitTimeExceeded) ch.frankenPHPContext.reject(ErrMaxWaitTimeExceeded)
return ErrMaxWaitTimeExceeded return ErrMaxWaitTimeExceeded
} }

View File

@@ -1,6 +1,7 @@
package frankenphp package frankenphp
import ( import (
"context"
"net/http" "net/http"
) )
@@ -10,7 +11,7 @@ type Workers interface {
// The generated HTTP response will be written through the provided writer. // The generated HTTP response will be written through the provided writer.
SendRequest(rw http.ResponseWriter, r *http.Request) error SendRequest(rw http.ResponseWriter, r *http.Request) error
// SendMessage calls the closure passed to frankenphp_handle_request(), passes message as a parameter, and returns the value produced by the closure. // SendMessage calls the closure passed to frankenphp_handle_request(), passes message as a parameter, and returns the value produced by the closure.
SendMessage(message any, rw http.ResponseWriter) (any, error) SendMessage(ctx context.Context, message any, rw http.ResponseWriter) (any, error)
// NumThreads returns the number of available threads. // NumThreads returns the number of available threads.
NumThreads() int NumThreads() int
} }
@@ -43,14 +44,14 @@ func (w *extensionWorkers) NumThreads() int {
} }
// EXPERIMENTAL: SendMessage sends a message to the worker and waits for a response. // EXPERIMENTAL: SendMessage sends a message to the worker and waits for a response.
func (w *extensionWorkers) SendMessage(message any, rw http.ResponseWriter) (any, error) { func (w *extensionWorkers) SendMessage(ctx context.Context, message any, rw http.ResponseWriter) (any, error) {
fc := newFrankenPHPContext() fc := newFrankenPHPContext()
fc.logger = logger fc.logger = globalLogger
fc.worker = w.internalWorker fc.worker = w.internalWorker
fc.responseWriter = rw fc.responseWriter = rw
fc.handlerParameters = message fc.handlerParameters = message
err := w.internalWorker.handleRequest(fc) err := w.internalWorker.handleRequest(contextHolder{context.WithValue(ctx, contextKey, fc), fc})
return fc.handlerReturn, err return fc.handlerReturn, err
} }

View File

@@ -69,7 +69,7 @@ func TestWorkerExtensionSendMessage(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(Shutdown) t.Cleanup(Shutdown)
ret, err := externalWorker.SendMessage("Hello Workers", nil) ret, err := externalWorker.SendMessage(t.Context(), "Hello Workers", nil)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "received message: Hello Workers", ret) assert.Equal(t, "received message: Hello Workers", ret)