From 62d3e5255f4ee6625f31d6fba2b129297c941cf7 Mon Sep 17 00:00:00 2001
From: wxiaoguang <wxiaoguang@gmail.com>
Date: Mon, 19 Feb 2024 01:39:04 +0800
Subject: [PATCH] Port "Use general token signing secret"

Port of https://github.com/go-gitea/gitea/pull/29205

Use a clearly defined "signing secret" for token signing.

(cherry picked from commit 8be198cdef0a486f417663b1fd6878458d7e5d92)
---
 modules/base/tool.go                         |  2 +-
 modules/context/context.go                   |  3 +-
 modules/generate/generate.go                 | 15 ++++++++
 modules/generate/generate_test.go            | 34 ++++++++++++++++++
 modules/setting/lfs.go                       | 23 ++++++------
 modules/setting/oauth2.go                    | 38 +++++++++++++++-----
 modules/setting/oauth2_test.go               | 34 ++++++++++++++++++
 modules/util/util.go                         | 11 ------
 modules/util/util_test.go                    | 14 --------
 services/actions/auth.go                     |  4 +--
 services/actions/auth_test.go                |  2 +-
 services/auth/source/oauth2/jwtsigningkey.go |  8 +----
 services/packages/auth.go                    |  4 +--
 13 files changed, 131 insertions(+), 61 deletions(-)
 create mode 100644 modules/generate/generate_test.go
 create mode 100644 modules/setting/oauth2_test.go

diff --git a/modules/base/tool.go b/modules/base/tool.go
index b72f3a1930..168a2220b2 100644
--- a/modules/base/tool.go
+++ b/modules/base/tool.go
@@ -115,7 +115,7 @@ func CreateTimeLimitCode(data string, minutes int, startInf any) string {
 
 	// create sha1 encode string
 	sh := sha1.New()
-	_, _ = sh.Write([]byte(fmt.Sprintf("%s%s%s%s%d", data, setting.SecretKey, startStr, endStr, minutes)))
+	_, _ = sh.Write([]byte(fmt.Sprintf("%s%s%s%s%d", data, hex.EncodeToString(setting.GetGeneralTokenSigningSecret()), startStr, endStr, minutes)))
 	encoded := hex.EncodeToString(sh.Sum(nil))
 
 	code := fmt.Sprintf("%s%06d%s", startStr, minutes, encoded)
diff --git a/modules/context/context.go b/modules/context/context.go
index 4d367b3242..66732eaa8a 100644
--- a/modules/context/context.go
+++ b/modules/context/context.go
@@ -6,6 +6,7 @@ package context
 
 import (
 	"context"
+	"encoding/hex"
 	"fmt"
 	"html/template"
 	"io"
@@ -124,7 +125,7 @@ func NewWebContext(base *Base, render Render, session session.Store) *Context {
 func Contexter() func(next http.Handler) http.Handler {
 	rnd := templates.HTMLRenderer()
 	csrfOpts := CsrfOptions{
-		Secret:         setting.SecretKey,
+		Secret:         hex.EncodeToString(setting.GetGeneralTokenSigningSecret()),
 		Cookie:         setting.CSRFCookieName,
 		SetCookie:      true,
 		Secure:         setting.SessionConfig.Secure,
diff --git a/modules/generate/generate.go b/modules/generate/generate.go
index df3e2474f9..41a6aa2815 100644
--- a/modules/generate/generate.go
+++ b/modules/generate/generate.go
@@ -7,6 +7,7 @@ package generate
 import (
 	"crypto/rand"
 	"encoding/base64"
+	"fmt"
 	"io"
 	"time"
 
@@ -38,6 +39,20 @@ func NewInternalToken() (string, error) {
 	return internalToken, nil
 }
 
+const defaultJwtSecretLen = 32
+
+// DecodeJwtSecret decodes a base64 encoded jwt secret into bytes, and check its length
+func DecodeJwtSecret(src string) ([]byte, error) {
+	encoding := base64.RawURLEncoding
+	decoded := make([]byte, encoding.DecodedLen(len(src))+3)
+	if n, err := encoding.Decode(decoded, []byte(src)); err != nil {
+		return nil, err
+	} else if n != defaultJwtSecretLen {
+		return nil, fmt.Errorf("invalid base64 decoded length: %d, expects: %d", n, defaultJwtSecretLen)
+	}
+	return decoded[:defaultJwtSecretLen], nil
+}
+
 // NewJwtSecret generates a new base64 encoded value intended to be used for JWT secrets.
 func NewJwtSecret() ([]byte, string, error) {
 	bytes := make([]byte, 32)
diff --git a/modules/generate/generate_test.go b/modules/generate/generate_test.go
new file mode 100644
index 0000000000..7d023b23ad
--- /dev/null
+++ b/modules/generate/generate_test.go
@@ -0,0 +1,34 @@
+// Copyright 2024 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package generate
+
+import (
+	"encoding/base64"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestDecodeJwtSecret(t *testing.T) {
+	_, err := DecodeJwtSecret("abcd")
+	assert.ErrorContains(t, err, "invalid base64 decoded length")
+	_, err = DecodeJwtSecret(strings.Repeat("a", 64))
+	assert.ErrorContains(t, err, "invalid base64 decoded length")
+
+	str32 := strings.Repeat("x", 32)
+	encoded32 := base64.RawURLEncoding.EncodeToString([]byte(str32))
+	decoded32, err := DecodeJwtSecret(encoded32)
+	assert.NoError(t, err)
+	assert.Equal(t, str32, string(decoded32))
+}
+
+func TestNewJwtSecret(t *testing.T) {
+	secret, encoded, err := NewJwtSecret()
+	assert.NoError(t, err)
+	assert.Len(t, secret, 32)
+	decoded, err := DecodeJwtSecret(encoded)
+	assert.NoError(t, err)
+	assert.Equal(t, secret, decoded)
+}
diff --git a/modules/setting/lfs.go b/modules/setting/lfs.go
index 7ab90669e7..750101747f 100644
--- a/modules/setting/lfs.go
+++ b/modules/setting/lfs.go
@@ -4,22 +4,19 @@
 package setting
 
 import (
-	"encoding/base64"
 	"fmt"
 	"time"
 
 	"code.gitea.io/gitea/modules/generate"
-	"code.gitea.io/gitea/modules/util"
 )
 
 // LFS represents the configuration for Git LFS
 var LFS = struct {
-	StartServer     bool          `ini:"LFS_START_SERVER"`
-	JWTSecretBase64 string        `ini:"LFS_JWT_SECRET"`
-	JWTSecretBytes  []byte        `ini:"-"`
-	HTTPAuthExpiry  time.Duration `ini:"LFS_HTTP_AUTH_EXPIRY"`
-	MaxFileSize     int64         `ini:"LFS_MAX_FILE_SIZE"`
-	LocksPagingNum  int           `ini:"LFS_LOCKS_PAGING_NUM"`
+	StartServer    bool          `ini:"LFS_START_SERVER"`
+	JWTSecretBytes []byte        `ini:"-"`
+	HTTPAuthExpiry time.Duration `ini:"LFS_HTTP_AUTH_EXPIRY"`
+	MaxFileSize    int64         `ini:"LFS_MAX_FILE_SIZE"`
+	LocksPagingNum int           `ini:"LFS_LOCKS_PAGING_NUM"`
 
 	Storage *Storage
 }{}
@@ -61,10 +58,10 @@ func loadLFSFrom(rootCfg ConfigProvider) error {
 		return nil
 	}
 
-	LFS.JWTSecretBase64 = loadSecret(rootCfg.Section("server"), "LFS_JWT_SECRET_URI", "LFS_JWT_SECRET")
-	LFS.JWTSecretBytes, err = util.Base64FixedDecode(base64.RawURLEncoding, []byte(LFS.JWTSecretBase64), 32)
+	jwtSecretBase64 := loadSecret(rootCfg.Section("server"), "LFS_JWT_SECRET_URI", "LFS_JWT_SECRET")
+	LFS.JWTSecretBytes, err = generate.DecodeJwtSecret(jwtSecretBase64)
 	if err != nil {
-		LFS.JWTSecretBytes, LFS.JWTSecretBase64, err = generate.NewJwtSecret()
+		LFS.JWTSecretBytes, jwtSecretBase64, err = generate.NewJwtSecret()
 		if err != nil {
 			return fmt.Errorf("error generating JWT Secret for custom config: %v", err)
 		}
@@ -74,8 +71,8 @@ func loadLFSFrom(rootCfg ConfigProvider) error {
 		if err != nil {
 			return fmt.Errorf("error saving JWT Secret for custom config: %v", err)
 		}
-		rootCfg.Section("server").Key("LFS_JWT_SECRET").SetValue(LFS.JWTSecretBase64)
-		saveCfg.Section("server").Key("LFS_JWT_SECRET").SetValue(LFS.JWTSecretBase64)
+		rootCfg.Section("server").Key("LFS_JWT_SECRET").SetValue(jwtSecretBase64)
+		saveCfg.Section("server").Key("LFS_JWT_SECRET").SetValue(jwtSecretBase64)
 		if err := saveCfg.Save(); err != nil {
 			return fmt.Errorf("error saving JWT Secret for custom config: %v", err)
 		}
diff --git a/modules/setting/oauth2.go b/modules/setting/oauth2.go
index e93ce188df..d3c4d5c387 100644
--- a/modules/setting/oauth2.go
+++ b/modules/setting/oauth2.go
@@ -4,13 +4,12 @@
 package setting
 
 import (
-	"encoding/base64"
 	"math"
 	"path/filepath"
+	"sync/atomic"
 
 	"code.gitea.io/gitea/modules/generate"
 	"code.gitea.io/gitea/modules/log"
-	"code.gitea.io/gitea/modules/util"
 )
 
 // OAuth2UsernameType is enum describing the way gitea 'name' should be generated from oauth2 data
@@ -98,7 +97,6 @@ var OAuth2 = struct {
 	RefreshTokenExpirationTime int64
 	InvalidateRefreshTokens    bool
 	JWTSigningAlgorithm        string `ini:"JWT_SIGNING_ALGORITHM"`
-	JWTSecretBase64            string `ini:"JWT_SECRET"`
 	JWTSigningPrivateKeyFile   string `ini:"JWT_SIGNING_PRIVATE_KEY_FILE"`
 	MaxTokenLength             int
 	DefaultApplications        []string
@@ -130,28 +128,50 @@ func loadOAuth2From(rootCfg ConfigProvider) {
 		return
 	}
 
-	OAuth2.JWTSecretBase64 = loadSecret(sec, "JWT_SECRET_URI", "JWT_SECRET")
+	jwtSecretBase64 := loadSecret(sec, "JWT_SECRET_URI", "JWT_SECRET")
 
 	if !filepath.IsAbs(OAuth2.JWTSigningPrivateKeyFile) {
 		OAuth2.JWTSigningPrivateKeyFile = filepath.Join(AppDataPath, OAuth2.JWTSigningPrivateKeyFile)
 	}
 
 	if InstallLock {
-		if _, err := util.Base64FixedDecode(base64.RawURLEncoding, []byte(OAuth2.JWTSecretBase64), 32); err != nil {
-			_, OAuth2.JWTSecretBase64, err = generate.NewJwtSecret()
+		jwtSecretBytes, err := generate.DecodeJwtSecret(jwtSecretBase64)
+		if err != nil {
+			jwtSecretBytes, jwtSecretBase64, err = generate.NewJwtSecret()
 			if err != nil {
 				log.Fatal("error generating JWT secret: %v", err)
 			}
-
 			saveCfg, err := rootCfg.PrepareSaving()
 			if err != nil {
 				log.Fatal("save oauth2.JWT_SECRET failed: %v", err)
 			}
-			rootCfg.Section("oauth2").Key("JWT_SECRET").SetValue(OAuth2.JWTSecretBase64)
-			saveCfg.Section("oauth2").Key("JWT_SECRET").SetValue(OAuth2.JWTSecretBase64)
+			rootCfg.Section("oauth2").Key("JWT_SECRET").SetValue(jwtSecretBase64)
+			saveCfg.Section("oauth2").Key("JWT_SECRET").SetValue(jwtSecretBase64)
 			if err := saveCfg.Save(); err != nil {
 				log.Fatal("save oauth2.JWT_SECRET failed: %v", err)
 			}
 		}
+		generalSigningSecret.Store(&jwtSecretBytes)
 	}
 }
+
+// generalSigningSecret is used as container for a []byte value
+// instead of an additional mutex, we use CompareAndSwap func to change the value thread save
+var generalSigningSecret atomic.Pointer[[]byte]
+
+func GetGeneralTokenSigningSecret() []byte {
+	old := generalSigningSecret.Load()
+	if old == nil || len(*old) == 0 {
+		jwtSecret, _, err := generate.NewJwtSecret()
+		if err != nil {
+			log.Fatal("Unable to generate general JWT secret: %s", err.Error())
+		}
+		if generalSigningSecret.CompareAndSwap(old, &jwtSecret) {
+			// FIXME: in main branch, the signing token should be refactored (eg: one unique for LFS/OAuth2/etc ...)
+			log.Warn("OAuth2 is not enabled, unable to use a persistent signing secret, a new one is generated, which is not persistent between restarts and cluster nodes")
+			return jwtSecret
+		}
+		return *generalSigningSecret.Load()
+	}
+	return *old
+}
diff --git a/modules/setting/oauth2_test.go b/modules/setting/oauth2_test.go
new file mode 100644
index 0000000000..da36d100aa
--- /dev/null
+++ b/modules/setting/oauth2_test.go
@@ -0,0 +1,34 @@
+// Copyright 2024 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package setting
+
+import (
+	"testing"
+
+	"code.gitea.io/gitea/modules/generate"
+	"code.gitea.io/gitea/modules/test"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestGetGeneralSigningSecret(t *testing.T) {
+	// when there is no general signing secret, it should be generated, and keep the same value
+	assert.Nil(t, generalSigningSecret.Load())
+	s1 := GetGeneralTokenSigningSecret()
+	assert.NotNil(t, s1)
+	s2 := GetGeneralTokenSigningSecret()
+	assert.Equal(t, s1, s2)
+
+	// the config value should always override any pre-generated value
+	cfg, _ := NewConfigProviderFromData(`
+[oauth2]
+JWT_SECRET = BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
+`)
+	defer test.MockVariableValue(&InstallLock, true)()
+	loadOAuth2From(cfg)
+	actual := GetGeneralTokenSigningSecret()
+	expected, _ := generate.DecodeJwtSecret("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")
+	assert.Len(t, actual, 32)
+	assert.EqualValues(t, expected, actual)
+}
diff --git a/modules/util/util.go b/modules/util/util.go
index c47931f6c9..0e5c6a4e64 100644
--- a/modules/util/util.go
+++ b/modules/util/util.go
@@ -6,7 +6,6 @@ package util
 import (
 	"bytes"
 	"crypto/rand"
-	"encoding/base64"
 	"fmt"
 	"math/big"
 	"strconv"
@@ -246,13 +245,3 @@ func ToFloat64(number any) (float64, error) {
 func ToPointer[T any](val T) *T {
 	return &val
 }
-
-func Base64FixedDecode(encoding *base64.Encoding, src []byte, length int) ([]byte, error) {
-	decoded := make([]byte, encoding.DecodedLen(len(src))+3)
-	if n, err := encoding.Decode(decoded, src); err != nil {
-		return nil, err
-	} else if n != length {
-		return nil, fmt.Errorf("invalid base64 decoded length: %d, expects: %d", n, length)
-	}
-	return decoded[:length], nil
-}
diff --git a/modules/util/util_test.go b/modules/util/util_test.go
index 8509d8aced..c5830ce01c 100644
--- a/modules/util/util_test.go
+++ b/modules/util/util_test.go
@@ -4,7 +4,6 @@
 package util
 
 import (
-	"encoding/base64"
 	"regexp"
 	"strings"
 	"testing"
@@ -234,16 +233,3 @@ func TestToPointer(t *testing.T) {
 	val123 := 123
 	assert.False(t, &val123 == ToPointer(val123))
 }
-
-func TestBase64FixedDecode(t *testing.T) {
-	_, err := Base64FixedDecode(base64.RawURLEncoding, []byte("abcd"), 32)
-	assert.ErrorContains(t, err, "invalid base64 decoded length")
-	_, err = Base64FixedDecode(base64.RawURLEncoding, []byte(strings.Repeat("a", 64)), 32)
-	assert.ErrorContains(t, err, "invalid base64 decoded length")
-
-	str32 := strings.Repeat("x", 32)
-	encoded32 := base64.RawURLEncoding.EncodeToString([]byte(str32))
-	decoded32, err := Base64FixedDecode(base64.RawURLEncoding, []byte(encoded32), 32)
-	assert.NoError(t, err)
-	assert.Equal(t, str32, string(decoded32))
-}
diff --git a/services/actions/auth.go b/services/actions/auth.go
index 53e68f0b71..e0f9a9015d 100644
--- a/services/actions/auth.go
+++ b/services/actions/auth.go
@@ -38,7 +38,7 @@ func CreateAuthorizationToken(taskID, runID, jobID int64) (string, error) {
 	}
 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
 
-	tokenString, err := token.SignedString([]byte(setting.SecretKey))
+	tokenString, err := token.SignedString(setting.GetGeneralTokenSigningSecret())
 	if err != nil {
 		return "", err
 	}
@@ -62,7 +62,7 @@ func ParseAuthorizationToken(req *http.Request) (int64, error) {
 		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
 			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
 		}
-		return []byte(setting.SecretKey), nil
+		return setting.GetGeneralTokenSigningSecret(), nil
 	})
 	if err != nil {
 		return 0, err
diff --git a/services/actions/auth_test.go b/services/actions/auth_test.go
index f6288ccd5a..1f62f17f52 100644
--- a/services/actions/auth_test.go
+++ b/services/actions/auth_test.go
@@ -20,7 +20,7 @@ func TestCreateAuthorizationToken(t *testing.T) {
 	assert.NotEqual(t, "", token)
 	claims := jwt.MapClaims{}
 	_, err = jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (interface{}, error) {
-		return []byte(setting.SecretKey), nil
+		return setting.GetGeneralTokenSigningSecret(), nil
 	})
 	assert.Nil(t, err)
 	scp, ok := claims["scp"]
diff --git a/services/auth/source/oauth2/jwtsigningkey.go b/services/auth/source/oauth2/jwtsigningkey.go
index eca0b8b7e1..070fffe60f 100644
--- a/services/auth/source/oauth2/jwtsigningkey.go
+++ b/services/auth/source/oauth2/jwtsigningkey.go
@@ -300,7 +300,7 @@ func InitSigningKey() error {
 	case "HS384":
 		fallthrough
 	case "HS512":
-		key, err = loadSymmetricKey()
+		key = setting.GetGeneralTokenSigningSecret()
 	case "RS256":
 		fallthrough
 	case "RS384":
@@ -333,12 +333,6 @@ func InitSigningKey() error {
 	return nil
 }
 
-// loadSymmetricKey checks if the configured secret is valid.
-// If it is not valid, it will return an error.
-func loadSymmetricKey() (any, error) {
-	return util.Base64FixedDecode(base64.RawURLEncoding, []byte(setting.OAuth2.JWTSecretBase64), 32)
-}
-
 // loadOrCreateAsymmetricKey checks if the configured private key exists.
 // If it does not exist a new random key gets generated and saved on the configured path.
 func loadOrCreateAsymmetricKey() (any, error) {
diff --git a/services/packages/auth.go b/services/packages/auth.go
index 2f78b26f50..8263c28bed 100644
--- a/services/packages/auth.go
+++ b/services/packages/auth.go
@@ -33,7 +33,7 @@ func CreateAuthorizationToken(u *user_model.User) (string, error) {
 	}
 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
 
-	tokenString, err := token.SignedString([]byte(setting.SecretKey))
+	tokenString, err := token.SignedString(setting.GetGeneralTokenSigningSecret())
 	if err != nil {
 		return "", err
 	}
@@ -57,7 +57,7 @@ func ParseAuthorizationToken(req *http.Request) (int64, error) {
 		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
 			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
 		}
-		return []byte(setting.SecretKey), nil
+		return setting.GetGeneralTokenSigningSecret(), nil
 	})
 	if err != nil {
 		return 0, err