How to force authentication in Django Rest Framework tests

Posted by fkaflof6 on 2023-03-09 in Go

I'm having trouble testing the API built in the DRF tutorial: https://www.django-rest-framework.org/tutorial/1-serialization/
My view:

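from rest_framework import generics, permissions

from snippets.models import Snippet
from snippets.serializers import SnippetSerializer

class SnippetList(generics.ListCreateAPIView):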
    permission_classes = [permissions.IsAuthenticatedOrReadOnly]
    queryset = Snippet.objects.all()
    serializer_class = SnippetSerializer
    def perform_create(self, serializer):
        serializer.save(owner=self.request.user)

And the test class for the view:

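import json

from django.contrib.auth.models import User
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient, APITestCase

from snippets.models import Snippet
from snippets.serializers import SnippetSerializer

class SnippetsList(APITestCase):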
    def setUp(self):
        self.user = User.objects.create_superuser(username='testowy', password='test')
        self.client = APIClient()
        Snippet.objects.create(code="print('hello')", owner=self.user)
        Snippet.objects.create(code="print('world')", owner=self.user)
        self.payload = {
            "code": "print(edit)"
        }

    def test_get_snippets_list(self):
        response = self.client.get(reverse('snippet_list'))
        snippets = Snippet.objects.all()
        serializer = SnippetSerializer(snippets, many=True)
        self.assertEqual(response.data, serializer.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

    def test_post_snippets_list(self):
        self.client.force_authenticate(self.user)
        response = self.client.post(reverse('snippet_list'), json.dumps(self.payload), format='json')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

When I run the tests with python manage.py test, the first test (the GET request) passes, but the second one fails with the following output:

self.assertEqual(response.status_code, status.HTTP_201_CREATED)
AssertionError: 400 != 201

When I test manually after logging in, everything works perfectly. Does anyone know what I'm missing here?

Answer #1 (iqxoj9l9)

Call force_authenticate inside the setUp method.
A better practice might be to create a PublicTests class for the tests that use a regular user,
and then a separate PrivateTests class in which you call force_authenticate in setUp():

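from django.contrib.auth.models import User
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient, APITestCase

from snippets.models import Snippet


class PrivateTests(APITestCase):
    """Tests that run with an authenticated user."""

    def setUp(self):
        self.user = User.objects.create_superuser(username='testowy', password='test')
        self.client = APIClient()
        # Every request made by self.client in this class is now authenticated as self.user
        self.client.force_authenticate(user=self.user)

    def test_get_snippets_list(self):
        """Test creating and listing snippet objects."""
        snippet1 = Snippet.objects.create(code="print('hello')", owner=self.user)
        snippet2 = Snippet.objects.create(code="print('world')", owner=self.user)

        user = self.user
        res = self.client.get(reverse("snippet_list"))

        self.assertEqual(res.status_code, status.HTTP_200_OK)
        # `snippets` is the related_name of Snippet.owner in the tutorial's model
        self.assertEqual(user.snippets.count(), 2)
        self.assertIn(snippet1, user.snippets.all())
        self.assertIn(snippet2, user.snippets.all())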
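One more thing worth checking in the question's test: a 400 (rather than 401 or 403) suggests the request got past authentication and permissions and was then rejected during validation. The POST passes json.dumps(self.payload) together with format='json'. With format='json', APIClient renders whatever data it is given as JSON itself, so an already-encoded string gets encoded a second time, and the view receives a JSON string instead of a JSON object, which the serializer cannot treat as a dictionary. The sketch below approximates the two encodings using only the standard json module. It does not call DRF, and encode_as_json is a hypothetical stand-in for the client's JSON rendering step.

```python
import json


# Hypothetical stand-in for what APIClient does with format='json':
# it JSON-encodes whatever object it is handed, exactly once.
def encode_as_json(data):
    return json.dumps(data)


payload = {"code": "print(edit)"}

# As in the question: the payload is encoded by hand, then encoded again.
buggy_body = encode_as_json(json.dumps(payload))
# Fix: hand over the dict and let the client do the single encoding.
fixed_body = encode_as_json(payload)

# What a JSON parser on the server side gets back from each request body.
parsed_buggy = json.loads(buggy_body)
parsed_fixed = json.loads(fixed_body)

print(type(parsed_buggy).__name__)  # str: a string, not an object
print(type(parsed_fixed).__name__)  # dict: an object with a "code" key
```

In the test itself, that would mean passing the dict directly: self.client.post(reverse('snippet_list'), self.payload, format='json').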
