From 3abda4fbb67daf4ec46bb433433eb82a42456ed6 Mon Sep 17 00:00:00 2001 From: Rob Landers Date: Wed, 9 Nov 2022 15:09:45 +0100 Subject: [PATCH] feat: handle aborted connection (#95) * Handle aborted connection * Handle when writing as well * return bytes written * optimize return * remove goroutine * fix style * Add tests * add missing newline --- frankenphp.c | 10 +++- frankenphp.go | 27 +++++++-- frankenphp_test.go | 98 ++++++++++++++++++++++++++++++++ testdata/connectionStatusLog.php | 17 ++++++ 4 files changed, 144 insertions(+), 8 deletions(-) create mode 100644 testdata/connectionStatusLog.php diff --git a/frankenphp.c b/frankenphp.c index 2c93220d..d4ad182b 100644 --- a/frankenphp.c +++ b/frankenphp.c @@ -405,7 +405,13 @@ static size_t frankenphp_ub_write(const char *str, size_t str_length) return 0; } - return go_ub_write(ctx->current_request ? ctx->current_request : ctx->main_request, (char *) str, str_length); + struct go_ub_write_return result = go_ub_write(ctx->current_request ? ctx->current_request : ctx->main_request, (char *) str, str_length); + + if (result.r1) { + php_handle_aborted_connection(); + } + + return result.r0; } static int frankenphp_send_headers(sapi_headers_struct *sapi_headers) @@ -445,7 +451,7 @@ static void frankenphp_sapi_flush(void *server_context) if (!ctx || ctx->current_request == 0) return; - go_sapi_flush(ctx->current_request); + if (go_sapi_flush(ctx->current_request)) php_handle_aborted_connection(); } static size_t frankenphp_read_post(char *buffer, size_t count_bytes) diff --git a/frankenphp.go b/frankenphp.go index 902034f1..c3031eec 100644 --- a/frankenphp.go +++ b/frankenphp.go @@ -124,13 +124,22 @@ type FrankenPHPContext struct { populated bool authPassword string - // Whether the request is already closed + // Whether the request is already closed by us closed sync.Once responseWriter http.ResponseWriter done chan interface{} } +func clientHasClosed(r *http.Request) bool { + select { + case <-r.Context().Done(): + return true + default: + return false + } +} + // NewRequestWithContext creates a new FrankenPHP request context. func NewRequestWithContext(r *http.Request, documentRoot string, l *zap.Logger) *http.Request { if l == nil { @@ -407,7 +416,7 @@ func go_execute_script(rh unsafe.Pointer) { } //export go_ub_write -func go_ub_write(rh C.uintptr_t, cString *C.char, length C.int) C.size_t { +func go_ub_write(rh C.uintptr_t, cString *C.char, length C.int) (C.size_t, C.bool) { r := cgo.Handle(rh).Value().(*http.Request) fc, _ := FromContext(r.Context()) @@ -426,7 +435,7 @@ func go_ub_write(rh C.uintptr_t, cString *C.char, length C.int) C.size_t { fc.Logger.Info(writer.(*bytes.Buffer).String()) } - return C.size_t(i) + return C.size_t(i), C.bool(clientHasClosed(r)) } //export go_register_variables @@ -486,20 +495,26 @@ func go_write_header(rh C.uintptr_t, status C.int) { } //export go_sapi_flush -func go_sapi_flush(rh C.uintptr_t) { +func go_sapi_flush(rh C.uintptr_t) bool { r := cgo.Handle(rh).Value().(*http.Request) fc := r.Context().Value(contextKey).(*FrankenPHPContext) if fc.responseWriter == nil { - return + return true } flusher, ok := fc.responseWriter.(http.Flusher) if !ok { - return + return true + } + + if clientHasClosed(r) { + return true } flusher.Flush() + + return false } //export go_read_post diff --git a/frankenphp_test.go b/frankenphp_test.go index 099d2f99..c30d478d 100644 --- a/frankenphp_test.go +++ b/frankenphp_test.go @@ -1,6 +1,7 @@ package frankenphp_test import ( + "context" "fmt" "io" "log" @@ -15,6 +16,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/dunglas/frankenphp" "github.com/stretchr/testify/assert" @@ -388,6 +390,102 @@ func testLog(t *testing.T, opts *testOptions) { }, opts) } +func TestConnectionAbortNormal_module(t *testing.T) { testConnectionAbortNormal(t, &testOptions{}) } +func TestConnectionAbortNormal_worker(t *testing.T) { + testConnectionAbortNormal(t, &testOptions{workerScript: "connectionStatusLog.php"}) +} +func testConnectionAbortNormal(t *testing.T, opts *testOptions) { + logger, logs := observer.New(zap.InfoLevel) + opts.logger = zap.New(logger) + + runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { + req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/connectionStatusLog.php?i=%d", i), nil) + w := httptest.NewRecorder() + + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) + cancel() + handler(w, req) + + // todo: remove conditions on wall clock to avoid race conditions/flakiness + time.Sleep(1000 * time.Microsecond) + var found bool + searched := fmt.Sprintf("request %d: 1", i) + for _, entry := range logs.All() { + if entry.Message == searched { + found = true + break + } + } + + assert.True(t, found) + }, opts) +} + +func TestConnectionAbortFlush_module(t *testing.T) { testConnectionAbortFlush(t, &testOptions{}) } +func TestConnectionAbortFlush_worker(t *testing.T) { + testConnectionAbortFlush(t, &testOptions{workerScript: "connectionStatusLog.php"}) +} +func testConnectionAbortFlush(t *testing.T, opts *testOptions) { + logger, logs := observer.New(zap.InfoLevel) + opts.logger = zap.New(logger) + + runTest(t, func(handler func(w http.ResponseWriter, response *http.Request), _ *httptest.Server, i int) { + req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/connectionStatusLog.php?i=%d&flush", i), nil) + w := httptest.NewRecorder() + + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) + cancel() + handler(w, req) + + // todo: remove conditions on wall clock to avoid race conditions/flakiness + time.Sleep(1000 * time.Microsecond) + var found bool + searched := fmt.Sprintf("request %d: 1", i) + for _, entry := range logs.All() { + if entry.Message == searched { + found = true + break + } + } + + assert.True(t, found) + }, opts) +} + +func TestConnectionAbortFinish_module(t *testing.T) { testConnectionAbortFinish(t, &testOptions{}) } +func TestConnectionAbortFinish_worker(t *testing.T) { + testConnectionAbortFinish(t, &testOptions{workerScript: "connectionStatusLog.php"}) +} +func testConnectionAbortFinish(t *testing.T, opts *testOptions) { + logger, logs := observer.New(zap.InfoLevel) + opts.logger = zap.New(logger) + + runTest(t, func(handler func(w http.ResponseWriter, response *http.Request), _ *httptest.Server, i int) { + req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/connectionStatusLog.php?i=%d&finish", i), nil) + w := httptest.NewRecorder() + + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) + cancel() + handler(w, req) + + // todo: remove conditions on wall clock to avoid race conditions/flakiness + time.Sleep(1000 * time.Microsecond) + var found bool + searched := fmt.Sprintf("request %d: 0", i) + for _, entry := range logs.All() { + if entry.Message == searched { + found = true + break + } + } + + assert.True(t, found) + }, opts) +} + func TestException_module(t *testing.T) { testException(t, &testOptions{}) } func TestException_worker(t *testing.T) { testException(t, &testOptions{workerScript: "exception.php"}) diff --git a/testdata/connectionStatusLog.php b/testdata/connectionStatusLog.php new file mode 100644 index 00000000..3e0a055c --- /dev/null +++ b/testdata/connectionStatusLog.php @@ -0,0 +1,17 @@ +