Skip to content

Commit 1c675e5

Browse files
authored
feat: remove the using of pickle(bkpaas-auth) (#266)
1 parent eda6815 commit 1c675e5

12 files changed

Lines changed: 359 additions & 109 deletions

File tree

sdks/bkpaas-auth/AGENTS.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
## Context
2+
3+
You are in the bkpaas-auth repo, helping implement features, fix bugs, and refactor existing code.
4+
5+
## Source code
6+
7+
* bkpaas-auth is a Django app that helps implement user authentication, used by BlueKing systems.
8+
* The main project is in `bkpaas_auth/`.
9+
* Unit tests are placed in the `tests/` directory, following pytest conventions.
10+
11+
## Coding style
12+
13+
* For Python files, follow PEP-8.
14+
* For Python files, run `ruff format` to format after edits.
15+
16+
## Common workflows
17+
18+
### Running tests
19+
20+
* Run all tests: `poetry run pytest -s`
21+
* Run some tests: `poetry run pytest -s tests/filename.py`
22+
* ALWAYS prefer specifying test files for efficiency.

sdks/bkpaas-auth/bkpaas_auth/backends.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
# -*- coding: utf-8 -*-
22
import inspect
33
import logging
4-
import pickle
54
from typing import Dict, Optional, Union
65

76
from django.conf import settings
87
from django.contrib.auth import get_user_model
98
from django.contrib.auth.models import AnonymousUser
109
from django.core.exceptions import ImproperlyConfigured
1110
from django.http import HttpRequest
12-
from django.utils.encoding import force_bytes
1311

1412
from bkpaas_auth.conf import bkauth_settings
1513
from bkpaas_auth.core.constants import ProviderType
@@ -108,11 +106,15 @@ def get_token_from_session(self, request: HttpRequest) -> Optional[LoginToken]:
108106
if "user_token" not in request.session:
109107
return None
110108

109+
raw_user_token = request.session["user_token"]
110+
if not isinstance(raw_user_token, str) or not raw_user_token.startswith("{"):
111+
logger.warning("ignore legacy or invalid session user_token payload")
112+
return None
113+
111114
try:
112-
user_token_pickled = force_bytes(request.session["user_token"], "latin1")
113-
user_token: LoginToken = pickle.loads(user_token_pickled)
115+
user_token: LoginToken = LoginToken.parse_json(raw_user_token)
114116
except Exception:
115-
logger.exception("pickle loads user_token failed")
117+
logger.exception("deserialize user_token failed")
116118
return None
117119

118120
# token 已经过期则不返回,否则会出现 403

sdks/bkpaas-auth/bkpaas_auth/core/token.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import logging
77
from abc import abstractmethod
8-
from typing import NamedTuple, Optional
8+
from typing import Any, ClassVar, NamedTuple, Optional, Tuple
99

1010
from django.utils.timezone import now
1111
from django.utils.translation import get_language
@@ -23,10 +23,21 @@
2323
from bkpaas_auth.core.services import get_app_credentials
2424
from bkpaas_auth.core.user_info import BkUserInfo, RtxUserInfo, UserInfo
2525
from bkpaas_auth.models import User
26-
from bkpaas_auth.utils import scrub_data
26+
from bkpaas_auth.utils import deserialize_datetime, scrub_data, serialize_datetime
2727

2828
logger = logging.getLogger(__name__)
2929

30+
# Constants related to the serialization of UserInfo in LoginToken.
31+
#
32+
# The field name to store the type of user_info instance.
33+
_USER_INFO_TYPE_FIELD = "_user_info_type"
34+
# The mapping between user_info type name and the actual class.
35+
_USER_INFO_TYPES: dict[str, type[UserInfo]] = {
36+
UserInfo.__name__: UserInfo,
37+
RtxUserInfo.__name__: RtxUserInfo,
38+
BkUserInfo.__name__: BkUserInfo,
39+
}
40+
3041

3142
class UserAccount(NamedTuple):
3243
"""
@@ -181,6 +192,12 @@ class LoginToken:
181192
"""Access token object"""
182193

183194
token_timeout_margin = 300
195+
_json_fields: ClassVar[Tuple[str, ...]] = (
196+
"login_token",
197+
"expires_at",
198+
"issued_at",
199+
"user_info",
200+
)
184201

185202
def __init__(self, login_token=None, expires_in=None):
186203
assert login_token, "Must provide token string"
@@ -200,6 +217,49 @@ def make_user(self, provider_type):
200217
self.user_info.provider_type = provider_type
201218
return create_user_from_token(self)
202219

220+
def dump_json(self) -> str:
    """Serialize the token to a JSON string.

    Every attribute named in ``_json_fields`` is written out. The two
    timestamps go through ``serialize_datetime`` and ``user_info`` is
    embedded as a nested JSON object. The class name of ``user_info`` is
    stored under ``_USER_INFO_TYPE_FIELD``.

    :raises TypeError: if the class name of ``user_info`` is not a key of
        ``_USER_INFO_TYPES``.
    """
    type_name = type(self.user_info).__name__
    if type_name not in _USER_INFO_TYPES:
        raise TypeError(f"unsupported user info type: {type_name}")

    def _encode(name):
        # Convert a single attribute into its JSON-compatible form.
        raw = getattr(self, name)
        if name in ("expires_at", "issued_at"):
            return serialize_datetime(raw)
        if name == "user_info":
            # Round-trip through the user info's own serializer so the result
            # is a plain dict nested inside this payload.
            return json.loads(raw.dump_json())
        return raw

    document = {name: _encode(name) for name in self._json_fields}
    document[_USER_INFO_TYPE_FIELD] = type_name
    return json.dumps(document)
238+
239+
@classmethod
def parse_json(cls, payload: str | dict[str, Any]) -> "LoginToken":
    """Build a token from a JSON string or an already-decoded dict.

    :param payload: JSON text or a dict holding every name in
        ``_json_fields`` plus the ``_USER_INFO_TYPE_FIELD`` entry.
    :raises TypeError: if the decoded payload is not a dict.
    :raises ValueError: if the stored user info type is not a key of
        ``_USER_INFO_TYPES``.
    :raises KeyError: if a name in ``_json_fields`` is missing from the payload.
    """
    data = json.loads(payload) if isinstance(payload, str) else payload
    if not isinstance(data, dict):
        raise TypeError(f"serialized payload must be dict, got: {type(data)!r}")

    type_name = data.get(_USER_INFO_TYPE_FIELD)
    if type_name not in _USER_INFO_TYPES:
        raise ValueError(f"unexpected serialized type: {type_name!r}")
    info_cls = _USER_INFO_TYPES[type_name]

    # Bypass __init__ so the original issued/expires timestamps survive the round trip.
    token = cls.__new__(cls)
    for name in cls._json_fields:
        raw = data[name]
        if name in ("expires_at", "issued_at"):
            raw = deserialize_datetime(raw)
        elif name == "user_info":
            raw = info_cls.parse_json(raw)
        setattr(token, name, raw)
    return token
262+
203263

204264
def mocked_create_user_from_token(
205265
token: LoginToken, provider_type: int = ProviderType.RTX, username: str = bkauth_settings.MOCKED_USER_NAME

sdks/bkpaas-auth/bkpaas_auth/core/user_info.py

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
2-
from typing import TYPE_CHECKING
2+
import json
3+
from typing import TYPE_CHECKING, Any, ClassVar, Tuple
34

45
from bkpaas_auth.core.constants import ProviderType
56
from bkpaas_auth.core.encoder import user_id_encoder
@@ -9,44 +10,88 @@
910

1011

1112
class UserInfo:
12-
"""Base class for Userinfo"""
13+
"""Base class for UserInfo"""
1314

1415
provider_type: ProviderType
16+
_json_fields: ClassVar[Tuple[str, ...]] = (
17+
"provider_type",
18+
"username",
19+
"display_name",
20+
"time_zone",
21+
"tenant_id",
22+
)
1523

1624
def __init__(self, username, **kwargs):
1725
self.username = username
1826
self.display_name = kwargs.get("display_name") or username
1927
self.time_zone = kwargs.get("time_zone")
2028
self.tenant_id = kwargs.get("tenant_id")
2129

22-
def provide(self, user: 'User'):
30+
def provide(self, user: "User"):
2331
user.provider_type = self.provider_type
2432
user.username = self.username
2533
user.bkpaas_user_id = user_id_encoder.encode(self.provider_type, self.username)
2634

2735
user.update_user_info(self.__dict__)
2836
return user
2937

38+
def dump_json(self) -> str:
39+
payload = {}
40+
for field in self._json_fields:
41+
value = getattr(self, field, None)
42+
if field == "provider_type" and value is not None:
43+
value = int(value)
44+
payload[field] = value
45+
return json.dumps(payload)
46+
47+
@classmethod
48+
def parse_json(cls, payload: str | dict[str, Any]) -> "UserInfo":
49+
data = cls._parse_json_payload(payload)
50+
user_info = cls.__new__(cls)
51+
for field in cls._json_fields:
52+
value = data.get(field)
53+
if field == "provider_type" and value is not None:
54+
value = ProviderType(value)
55+
setattr(user_info, field, value)
56+
return user_info
57+
3058
def __eq__(self, other):
3159
if not isinstance(other, UserInfo):
3260
return False
3361
return self.username == other.username
3462

63+
@staticmethod
64+
def _parse_json_payload(payload: str | dict[str, Any]) -> dict[str, Any]:
65+
"""Parse the JSON payload to dict."""
66+
if isinstance(payload, str):
67+
payload = json.loads(payload)
68+
if not isinstance(payload, dict):
69+
raise TypeError(f"serialized payload must be dict, got: {type(payload)!r}")
70+
return payload
71+
3572

3673
class RtxUserInfo(UserInfo):
3774
"""User info for RTX user"""
3875

3976
provider_type = ProviderType.RTX
4077
email_suffix = "@tencent.com"
4178

79+
_json_fields = UserInfo._json_fields + (
80+
"nickname",
81+
"chinese_name",
82+
"email",
83+
"phone",
84+
"avatar_url",
85+
)
86+
4287
def __init__(self, **kwargs):
4388
super().__init__(kwargs["LoginName"], **kwargs)
44-
self.nickname = kwargs['ChineseName']
45-
self.chinese_name = kwargs['ChineseName']
46-
self.email = f'{self.username}{self.email_suffix}'
89+
self.nickname = kwargs["ChineseName"]
90+
self.chinese_name = kwargs["ChineseName"]
91+
self.email = f"{self.username}{self.email_suffix}"
4792
# 用户 API 添加了限制,没有申请特殊权限的情况下无法获取手机信息
48-
self.phone = kwargs.get('MobilePhoneNumber', '')
49-
self.avatar_url = ''
93+
self.phone = kwargs.get("MobilePhoneNumber", "")
94+
self.avatar_url = ""
5095

5196
def __eq__(self, other):
5297
if not isinstance(other, RtxUserInfo):
@@ -64,16 +109,23 @@ class BkUserInfo(UserInfo):
64109
"""User info for Bk user"""
65110

66111
provider_type = ProviderType.BK
112+
_json_fields = UserInfo._json_fields + (
113+
"nickname",
114+
"chinese_name",
115+
"email",
116+
"phone",
117+
"avatar_url",
118+
)
67119

68120
def __init__(self, **kwargs):
69121
# bk_username 用户英文ID
70122
super().__init__(kwargs["bk_username"], **kwargs)
71123
# chname 用户中文名
72-
self.nickname = kwargs['chname']
73-
self.chinese_name = kwargs['chname']
74-
self.email = kwargs['email']
75-
self.phone = kwargs['phone']
76-
self.avatar_url = ''
124+
self.nickname = kwargs["chname"]
125+
self.chinese_name = kwargs["chname"]
126+
self.email = kwargs["email"]
127+
self.phone = kwargs["phone"]
128+
self.avatar_url = ""
77129

78130
def __eq__(self, other):
79131
if not isinstance(other, BkUserInfo):

sdks/bkpaas-auth/bkpaas_auth/middlewares.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
# -*- coding: utf-8 -*-
22
import json
33
import logging
4-
import pickle
54
import time
65
from typing import Dict
7-
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
86

97
from django.conf import settings
108
from django.contrib import auth
119
from django.http import HttpRequest, HttpResponse
12-
from django.utils.deprecation import MiddlewareMixin
13-
from django.utils.encoding import force_str
1410
from django.utils import timezone as dj_timezone
11+
from django.utils.deprecation import MiddlewareMixin
12+
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
1513

1614
from bkpaas_auth.backends import UniversalAuthBackend
1715
from bkpaas_auth.core.constants import ACCESS_PERMISSION_DENIED_CODE
@@ -24,7 +22,7 @@ class CookieLoginMiddleware(MiddlewareMixin):
2422
"""Call auth.login when user credential cookies changes"""
2523

2624
def process_request(self, request):
27-
assert hasattr(request, 'session'), (
25+
assert hasattr(request, "session"), (
2826
"The CookieLoginMiddleware requires session middleware "
2927
"to be installed. Edit your MIDDLEWARE%s setting to insert "
3028
"'django.contrib.sessions.middleware.SessionMiddleware' before "
@@ -44,7 +42,7 @@ def process_request(self, request):
4442
self.authenticate_and_login(request, credentials)
4543
except AccessPermissionDenied as e:
4644
resp = HttpResponse(
47-
json.dumps({'code': ACCESS_PERMISSION_DENIED_CODE, 'detail': str(e)}),
45+
json.dumps({"code": ACCESS_PERMISSION_DENIED_CODE, "detail": str(e)}),
4846
content_type="application/json",
4947
)
5048
resp.status_code = 403
@@ -57,7 +55,7 @@ def should_authenticate(
5755
) -> bool:
5856
"""Decide whether to re-authenticate current credentials or not"""
5957
# Force re-login if credentials is different from last time
60-
credentials_been_modified = credentials != request.session.get('auth_credentials', {})
58+
credentials_been_modified = credentials != request.session.get("auth_credentials", {})
6159
if credentials_been_modified:
6260
return True
6361

@@ -71,10 +69,10 @@ def authenticate_and_login(self, request: HttpRequest, credentials: Dict[str, st
7169
:params request: Current request object
7270
:params credentials: user credentials, such as uin/skey pair
7371
"""
74-
logger.debug('Authenticating credentials...')
72+
logger.debug("Authenticating credentials...")
7573
user = auth.authenticate(request=request, auth_credentials=credentials)
7674
if user is None or not user.is_authenticated:
77-
logger.info('Authentication failed, logout.')
75+
logger.info("Authentication failed, logout.")
7876
auth.logout(request)
7977
return
8078

@@ -83,13 +81,12 @@ def authenticate_and_login(self, request: HttpRequest, credentials: Dict[str, st
8381
logger.info("User is not validate by UniversalAuthBackend, skip login processes.")
8482
return
8583

86-
logger.debug('Authentication finished, username: %s', user.username)
87-
request.session['provider_type'] = user.provider_type.value
88-
request.session['bkpaas_user_id'] = user.bkpaas_user_id
89-
request.session['bkpaas_authenticated_at'] = time.time()
90-
request.session['auth_credentials'] = credentials
91-
# python3 compatibility
92-
request.session['user_token'] = force_str(pickle.dumps(user.token), 'latin1')
84+
logger.debug("Authentication finished, username: %s", user.username)
85+
request.session["provider_type"] = user.provider_type.value
86+
request.session["bkpaas_user_id"] = user.bkpaas_user_id
87+
request.session["bkpaas_authenticated_at"] = time.time()
88+
request.session["auth_credentials"] = credentials
89+
request.session["user_token"] = user.token.dump_json()
9390

9491
# Calling `auth.login` will rotate CSRF token and modify user session, only do this when the authenticated
9592
# user was different with the user stored in session. Otherwise CSRF token validation may fail due to the

sdks/bkpaas-auth/bkpaas_auth/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
from typing import Any, Dict
23

34
DEFAULT_SCRUBBED_FIELDS = (
@@ -46,3 +47,15 @@ def _key_is_sensitive(key: str) -> bool:
4647
else:
4748
current_result[key] = value
4849
return result
50+
51+
52+
def serialize_datetime(value: datetime.datetime) -> str:
    """Return the ISO 8601 string produced by ``value.isoformat()``.

    :raises TypeError: if ``value`` is not a ``datetime.datetime`` instance.
    """
    if isinstance(value, datetime.datetime):
        return value.isoformat()
    raise TypeError(f"datetime value required, got: {type(value)!r}")
56+
57+
58+
def deserialize_datetime(value: str) -> datetime.datetime:
    """Parse an ISO 8601 string with ``datetime.datetime.fromisoformat``.

    :raises TypeError: if ``value`` is not a ``str``.
    :raises ValueError: if the string is not accepted by ``fromisoformat``.
    """
    if isinstance(value, str):
        return datetime.datetime.fromisoformat(value)
    raise TypeError(f"datetime payload must be str, got: {type(value)!r}")

0 commit comments

Comments
 (0)