From 3183a465d71a13535e52589bb85b987176872fcd Mon Sep 17 00:00:00 2001
From: zeripath <art27@cantab.net>
Date: Mon, 31 May 2021 07:18:11 +0100
Subject: [PATCH] Make modules/context.Context a context.Context (#16031)

* Make modules/context.Context a context.Context

Signed-off-by: Andrew Thornton <art27@cantab.net>

* Simplify context calls

Signed-off-by: Andrew Thornton <art27@cantab.net>

* Set the base context for requests to the HammerContext

Signed-off-by: Andrew Thornton <art27@cantab.net>

Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
---
 modules/context/context.go      | 22 +++++++++++++++++++++-
 modules/graceful/server_http.go |  3 +++
 routers/admin/users.go          |  4 ++--
 routers/api/v1/admin/user.go    |  4 ++--
 routers/events/events.go        |  2 +-
 routers/install.go              |  4 ++--
 routers/private/manager.go      |  2 +-
 routers/private/restore_repo.go |  2 +-
 routers/repo/blame.go           |  2 +-
 routers/repo/lfs.go             |  2 +-
 routers/user/auth.go            | 12 ++++++------
 routers/user/auth_openid.go     |  4 ++--
 routers/user/setting/account.go |  2 +-
 services/archiver/archiver.go   |  4 ++--
 14 files changed, 46 insertions(+), 23 deletions(-)

diff --git a/modules/context/context.go b/modules/context/context.go
index d812d7b58c..d45e9ff87c 100644
--- a/modules/context/context.go
+++ b/modules/context/context.go
@@ -509,7 +509,7 @@ func (ctx *Context) ParamsInt64(p string) int64 {
 
 // SetParams set params into routes
 func (ctx *Context) SetParams(k, v string) {
-	chiCtx := chi.RouteContext(ctx.Req.Context())
+	chiCtx := chi.RouteContext(ctx)
 	chiCtx.URLParams.Add(strings.TrimPrefix(k, ":"), url.PathEscape(v))
 }
 
@@ -528,6 +528,26 @@ func (ctx *Context) Status(status int) {
 	ctx.Resp.WriteHeader(status)
 }
 
+// Deadline is part of the interface for context.Context and we pass this to the request context
+func (ctx *Context) Deadline() (deadline time.Time, ok bool) {
+	return ctx.Req.Context().Deadline()
+}
+
+// Done is part of the interface for context.Context and we pass this to the request context
+func (ctx *Context) Done() <-chan struct{} {
+	return ctx.Req.Context().Done()
+}
+
+// Err is part of the interface for context.Context and we pass this to the request context
+func (ctx *Context) Err() error {
+	return ctx.Req.Context().Err()
+}
+
+// Value is part of the interface for context.Context and we pass this to the request context
+func (ctx *Context) Value(key interface{}) interface{} {
+	return ctx.Req.Context().Value(key)
+}
+
 // Handler represents a custom handler
 type Handler func(*Context)
 
diff --git a/modules/graceful/server_http.go b/modules/graceful/server_http.go
index b101a10d91..4471e379ef 100644
--- a/modules/graceful/server_http.go
+++ b/modules/graceful/server_http.go
@@ -5,7 +5,9 @@
 package graceful
 
 import (
+	"context"
 	"crypto/tls"
+	"net"
 	"net/http"
 )
 
@@ -16,6 +18,7 @@ func newHTTPServer(network, address, name string, handler http.Handler) (*Server
 		WriteTimeout:   DefaultWriteTimeOut,
 		MaxHeaderBytes: DefaultMaxHeaderBytes,
 		Handler:        handler,
+		BaseContext:    func(net.Listener) context.Context { return GetManager().HammerContext() },
 	}
 	server.OnShutdown = func() {
 		httpServer.SetKeepAlivesEnabled(false)
diff --git a/routers/admin/users.go b/routers/admin/users.go
index 3b29eeefc1..a71a11dd8a 100644
--- a/routers/admin/users.go
+++ b/routers/admin/users.go
@@ -113,7 +113,7 @@ func NewUserPost(ctx *context.Context) {
 			ctx.RenderWithErr(password.BuildComplexityError(ctx), tplUserNew, &form)
 			return
 		}
-		pwned, err := password.IsPwned(ctx.Req.Context(), form.Password)
+		pwned, err := password.IsPwned(ctx, form.Password)
 		if pwned {
 			ctx.Data["Err_Password"] = true
 			errMsg := ctx.Tr("auth.password_pwned")
@@ -256,7 +256,7 @@ func EditUserPost(ctx *context.Context) {
 			ctx.RenderWithErr(password.BuildComplexityError(ctx), tplUserEdit, &form)
 			return
 		}
-		pwned, err := password.IsPwned(ctx.Req.Context(), form.Password)
+		pwned, err := password.IsPwned(ctx, form.Password)
 		if pwned {
 			ctx.Data["Err_Password"] = true
 			errMsg := ctx.Tr("auth.password_pwned")
diff --git a/routers/api/v1/admin/user.go b/routers/api/v1/admin/user.go
index 2d4a3815f4..4bbe7f77ba 100644
--- a/routers/api/v1/admin/user.go
+++ b/routers/api/v1/admin/user.go
@@ -88,7 +88,7 @@ func CreateUser(ctx *context.APIContext) {
 		ctx.Error(http.StatusBadRequest, "PasswordComplexity", err)
 		return
 	}
-	pwned, err := password.IsPwned(ctx.Req.Context(), form.Password)
+	pwned, err := password.IsPwned(ctx, form.Password)
 	if pwned {
 		if err != nil {
 			log.Error(err.Error())
@@ -162,7 +162,7 @@ func EditUser(ctx *context.APIContext) {
 			ctx.Error(http.StatusBadRequest, "PasswordComplexity", err)
 			return
 		}
-		pwned, err := password.IsPwned(ctx.Req.Context(), form.Password)
+		pwned, err := password.IsPwned(ctx, form.Password)
 		if pwned {
 			if err != nil {
 				log.Error(err.Error())
diff --git a/routers/events/events.go b/routers/events/events.go
index 2c1034038f..b140bf660c 100644
--- a/routers/events/events.go
+++ b/routers/events/events.go
@@ -42,7 +42,7 @@ func Events(ctx *context.Context) {
 	}
 
 	// Listen to connection close and un-register messageChan
-	notify := ctx.Req.Context().Done()
+	notify := ctx.Done()
 	ctx.Resp.Flush()
 
 	shutdownCtx := graceful.GetManager().ShutdownContext()
diff --git a/routers/install.go b/routers/install.go
index ef53422c4e..30340e99cd 100644
--- a/routers/install.go
+++ b/routers/install.go
@@ -400,7 +400,7 @@ func InstallPost(ctx *context.Context) {
 	}
 
 	// Re-read settings
-	PostInstallInit(ctx.Req.Context())
+	PostInstallInit(ctx)
 
 	// Create admin account
 	if len(form.AdminName) > 0 {
@@ -454,7 +454,7 @@ func InstallPost(ctx *context.Context) {
 
 	// Now get the http.Server from this request and shut it down
 	// NB: This is not our hammerable graceful shutdown this is http.Server.Shutdown
-	srv := ctx.Req.Context().Value(http.ServerContextKey).(*http.Server)
+	srv := ctx.Value(http.ServerContextKey).(*http.Server)
 	go func() {
 		if err := srv.Shutdown(graceful.GetManager().HammerContext()); err != nil {
 			log.Error("Unable to shutdown the install server! Error: %v", err)
diff --git a/routers/private/manager.go b/routers/private/manager.go
index 192c4947e7..1ccb184363 100644
--- a/routers/private/manager.go
+++ b/routers/private/manager.go
@@ -35,7 +35,7 @@ func FlushQueues(ctx *context.PrivateContext) {
 		})
 		return
 	}
-	err := queue.GetManager().FlushAll(ctx.Req.Context(), opts.Timeout)
+	err := queue.GetManager().FlushAll(ctx, opts.Timeout)
 	if err != nil {
 		ctx.JSON(http.StatusRequestTimeout, map[string]interface{}{
 			"err": fmt.Sprintf("%v", err),
diff --git a/routers/private/restore_repo.go b/routers/private/restore_repo.go
index c002de874a..df787e1b33 100644
--- a/routers/private/restore_repo.go
+++ b/routers/private/restore_repo.go
@@ -36,7 +36,7 @@ func RestoreRepo(ctx *myCtx.PrivateContext) {
 	}
 
 	if err := migrations.RestoreRepository(
-		ctx.Req.Context(),
+		ctx,
 		params.RepoDir,
 		params.OwnerName,
 		params.RepoName,
diff --git a/routers/repo/blame.go b/routers/repo/blame.go
index f5b228bdfe..1a3e1dcb9c 100644
--- a/routers/repo/blame.go
+++ b/routers/repo/blame.go
@@ -124,7 +124,7 @@ func RefBlame(ctx *context.Context) {
 		return
 	}
 
-	blameReader, err := git.CreateBlameReader(ctx.Req.Context(), models.RepoPath(userName, repoName), commitID, fileName)
+	blameReader, err := git.CreateBlameReader(ctx, models.RepoPath(userName, repoName), commitID, fileName)
 	if err != nil {
 		ctx.NotFound("CreateBlameReader", err)
 		return
diff --git a/routers/repo/lfs.go b/routers/repo/lfs.go
index 3a7ce2e23b..c17bd2f87a 100644
--- a/routers/repo/lfs.go
+++ b/routers/repo/lfs.go
@@ -414,7 +414,7 @@ func LFSPointerFiles(ctx *context.Context) {
 	err = func() error {
 		pointerChan := make(chan lfs.PointerBlob)
 		errChan := make(chan error, 1)
-		go lfs.SearchPointerBlobs(ctx.Req.Context(), ctx.Repo.GitRepo, pointerChan, errChan)
+		go lfs.SearchPointerBlobs(ctx, ctx.Repo.GitRepo, pointerChan, errChan)
 
 		numPointers := 0
 		var numAssociated, numNoExist, numAssociatable int
diff --git a/routers/user/auth.go b/routers/user/auth.go
index 5f8b1a6b99..827b7cdef0 100644
--- a/routers/user/auth.go
+++ b/routers/user/auth.go
@@ -1011,9 +1011,9 @@ func LinkAccountPostRegister(ctx *context.Context) {
 		case setting.ImageCaptcha:
 			valid = context.GetImageCaptcha().VerifyReq(ctx.Req)
 		case setting.ReCaptcha:
-			valid, err = recaptcha.Verify(ctx.Req.Context(), form.GRecaptchaResponse)
+			valid, err = recaptcha.Verify(ctx, form.GRecaptchaResponse)
 		case setting.HCaptcha:
-			valid, err = hcaptcha.Verify(ctx.Req.Context(), form.HcaptchaResponse)
+			valid, err = hcaptcha.Verify(ctx, form.HcaptchaResponse)
 		default:
 			ctx.ServerError("Unknown Captcha Type", fmt.Errorf("Unknown Captcha Type: %s", setting.Service.CaptchaType))
 			return
@@ -1153,9 +1153,9 @@ func SignUpPost(ctx *context.Context) {
 		case setting.ImageCaptcha:
 			valid = context.GetImageCaptcha().VerifyReq(ctx.Req)
 		case setting.ReCaptcha:
-			valid, err = recaptcha.Verify(ctx.Req.Context(), form.GRecaptchaResponse)
+			valid, err = recaptcha.Verify(ctx, form.GRecaptchaResponse)
 		case setting.HCaptcha:
-			valid, err = hcaptcha.Verify(ctx.Req.Context(), form.HcaptchaResponse)
+			valid, err = hcaptcha.Verify(ctx, form.HcaptchaResponse)
 		default:
 			ctx.ServerError("Unknown Captcha Type", fmt.Errorf("Unknown Captcha Type: %s", setting.Service.CaptchaType))
 			return
@@ -1191,7 +1191,7 @@ func SignUpPost(ctx *context.Context) {
 		ctx.RenderWithErr(password.BuildComplexityError(ctx), tplSignUp, &form)
 		return
 	}
-	pwned, err := password.IsPwned(ctx.Req.Context(), form.Password)
+	pwned, err := password.IsPwned(ctx, form.Password)
 	if pwned {
 		errMsg := ctx.Tr("auth.password_pwned")
 		if err != nil {
@@ -1620,7 +1620,7 @@ func ResetPasswdPost(ctx *context.Context) {
 		ctx.Data["Err_Password"] = true
 		ctx.RenderWithErr(password.BuildComplexityError(ctx), tplResetPassword, nil)
 		return
-	} else if pwned, err := password.IsPwned(ctx.Req.Context(), passwd); pwned || err != nil {
+	} else if pwned, err := password.IsPwned(ctx, passwd); pwned || err != nil {
 		errMsg := ctx.Tr("auth.password_pwned")
 		if err != nil {
 			log.Error(err.Error())
diff --git a/routers/user/auth_openid.go b/routers/user/auth_openid.go
index b1dfc6ada0..1a73a08c48 100644
--- a/routers/user/auth_openid.go
+++ b/routers/user/auth_openid.go
@@ -385,13 +385,13 @@ func RegisterOpenIDPost(ctx *context.Context) {
 				ctx.ServerError("", err)
 				return
 			}
-			valid, err = recaptcha.Verify(ctx.Req.Context(), form.GRecaptchaResponse)
+			valid, err = recaptcha.Verify(ctx, form.GRecaptchaResponse)
 		case setting.HCaptcha:
 			if err := ctx.Req.ParseForm(); err != nil {
 				ctx.ServerError("", err)
 				return
 			}
-			valid, err = hcaptcha.Verify(ctx.Req.Context(), form.HcaptchaResponse)
+			valid, err = hcaptcha.Verify(ctx, form.HcaptchaResponse)
 		default:
 			ctx.ServerError("Unknown Captcha Type", fmt.Errorf("Unknown Captcha Type: %s", setting.Service.CaptchaType))
 			return
diff --git a/routers/user/setting/account.go b/routers/user/setting/account.go
index e12d63ee02..48ab37d936 100644
--- a/routers/user/setting/account.go
+++ b/routers/user/setting/account.go
@@ -58,7 +58,7 @@ func AccountPost(ctx *context.Context) {
 		ctx.Flash.Error(ctx.Tr("form.password_not_match"))
 	} else if !password.IsComplexEnough(form.Password) {
 		ctx.Flash.Error(password.BuildComplexityError(ctx))
-	} else if pwned, err := password.IsPwned(ctx.Req.Context(), form.Password); pwned || err != nil {
+	} else if pwned, err := password.IsPwned(ctx, form.Password); pwned || err != nil {
 		errMsg := ctx.Tr("auth.password_pwned")
 		if err != nil {
 			log.Error(err.Error())
diff --git a/services/archiver/archiver.go b/services/archiver/archiver.go
index 359fc8b627..dfa6334d95 100644
--- a/services/archiver/archiver.go
+++ b/services/archiver/archiver.go
@@ -76,7 +76,7 @@ func (aReq *ArchiveRequest) IsComplete() bool {
 func (aReq *ArchiveRequest) WaitForCompletion(ctx *context.Context) bool {
 	select {
 	case <-aReq.cchan:
-	case <-ctx.Req.Context().Done():
+	case <-ctx.Done():
 	}
 
 	return aReq.IsComplete()
@@ -92,7 +92,7 @@ func (aReq *ArchiveRequest) TimedWaitForCompletion(ctx *context.Context, dur tim
 	case <-time.After(dur):
 		timeout = true
 	case <-aReq.cchan:
-	case <-ctx.Req.Context().Done():
+	case <-ctx.Done():
 	}
 
 	return aReq.IsComplete(), timeout