Skip to content

Commit 2646a5c

Browse files
authored
[fix] Made batch import operations atomic #551
Fixes #551
1 parent 43361fd commit 2646a5c

2 files changed

Lines changed: 59 additions & 14 deletions

File tree

openwisp_radius/base/models.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from django.core.cache import cache
1818
from django.core.exceptions import ObjectDoesNotExist, ValidationError
1919
from django.core.mail import send_mail
20-
from django.db import models
20+
from django.db import models, transaction
2121
from django.db.models import ProtectedError, Q
2222
from django.utils import timezone
2323
from django.utils.crypto import get_random_string
@@ -985,16 +985,17 @@ def clean(self):
985985
def add(self, reader, password_length=BATCH_DEFAULT_PASSWORD_LENGTH):
986986
users_list = []
987987
generated_passwords = []
988-
for row in reader:
989-
if len(row) == 5:
990-
user, password = self.get_or_create_user(
991-
row, users_list, password_length
992-
)
993-
users_list.append(user)
994-
if password:
995-
generated_passwords.append(password)
996-
for user in users_list:
997-
self.save_user(user)
988+
with transaction.atomic():
989+
for row in reader:
990+
if len(row) == 5:
991+
user, password = self.get_or_create_user(
992+
row, users_list, password_length
993+
)
994+
users_list.append(user)
995+
if password:
996+
generated_passwords.append(password)
997+
for user in users_list:
998+
self.save_user(user)
998999
for element in generated_passwords:
9991000
username, password, user_email = element
10001001
send_mail(
@@ -1012,9 +1013,10 @@ def csvfile_upload(
10121013
csv_data = csvfile.read()
10131014
csv_data = decode_byte_data(csv_data)
10141015
reader = csv.reader(StringIO(csv_data), delimiter=",")
1015-
self.full_clean()
1016-
self.save()
1017-
self.add(reader, password_length)
1016+
with transaction.atomic():
1017+
self.full_clean()
1018+
self.save()
1019+
self.add(reader, password_length)
10181020

10191021
def prefix_add(self, prefix, n, password_length=BATCH_DEFAULT_PASSWORD_LENGTH):
10201022
users_list, user_credentials = prefix_generate_users(prefix, n, password_length)

openwisp_radius/tests/test_batch_add_users.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from unittest.mock import patch
22

3+
from django.contrib.auth import get_user_model
34
from django.core.exceptions import ValidationError
5+
from django.db import IntegrityError
46
from django.urls import reverse
57

68
from ..utils import load_model
@@ -145,5 +147,46 @@ def test_verified_batch_user_creation(self):
145147
self.assertEqual(user.registered_user.method, "manual")
146148

147149

150+
class TestBatchAtomicity(FileMixin, BaseTransactionTestCase):
151+
def test_csv_upload_total_rollback(self):
152+
User = get_user_model()
153+
org = self._get_org()
154+
data = [
155+
["user_one", "pass123", "total@example.com", "John", "Doe"],
156+
["user_two", "pass123", "total@example.com", "Jane", "Doe"],
157+
]
158+
batch = RadiusBatch(
159+
name="total-rollback",
160+
strategy="csv",
161+
organization=org,
162+
csvfile=self._get_csvfile(data),
163+
)
164+
with self.assertRaises(IntegrityError):
165+
batch.csvfile_upload()
166+
self.assertFalse(RadiusBatch.objects.filter(name="total-rollback").exists())
167+
self.assertFalse(User.objects.filter(username="user_one").exists())
168+
169+
def test_add_method_internal_atomicity(self):
170+
User = get_user_model()
171+
org = self._get_org()
172+
data = [
173+
["user_one", "pass123", "duplicate@example.com", "John", "Doe"],
174+
["user_two", "pass123", "duplicate@example.com", "Jane", "Doe"],
175+
]
176+
batch = self._create_radius_batch(
177+
name="atomic-integrity-test",
178+
strategy="csv",
179+
organization=org,
180+
csvfile=self._get_csvfile(data),
181+
)
182+
with self.assertRaises(IntegrityError):
183+
batch.add(data)
184+
self.assertFalse(User.objects.filter(username="user_one").exists())
185+
self.assertTrue(
186+
RadiusBatch.objects.filter(name="atomic-integrity-test").exists()
187+
)
188+
self.assertEqual(batch.users.count(), 0)
189+
190+
148191
del BaseTestCase
149192
del BaseTransactionTestCase

0 commit comments

Comments
 (0)