diff --git a/duo_universal/client.py b/duo_universal/client.py index 0c78935..c6f6f42 100644 --- a/duo_universal/client.py +++ b/duo_universal/client.py @@ -130,7 +130,7 @@ def _create_jwt_args(self, endpoint): def __init__(self, client_id, client_secret, host, redirect_uri, duo_certs=DEFAULT_CA_CERT_PATH, use_duo_code_attribute=True, http_proxy=None, - exp_seconds=FIVE_MINUTES_IN_SECONDS): + exp_seconds=FIVE_MINUTES_IN_SECONDS, disable_ca_pinning=False): """ Initializes instance of Client class @@ -144,6 +144,10 @@ def __init__(self, client_id, client_secret, host, use_duo_code_attribute -- (Optional: default true) Flag to use `duo_code` instead of `code` for returned authorization parameter http_proxy -- (Optional) HTTP proxy to tunnel requests through exp_seconds -- (Optional) The number of seconds used for JWT expiry. Must be be at most 5 minutes. + disable_ca_pinning -- (Optional: default false) If True, uses the system's default + trusted CA certificates instead of Duo's bundled CA certificates. + TLS verification remains active. Cannot be used together with + custom duo_certs. """ self._validate_init_config(client_id, @@ -158,9 +162,16 @@ def __init__(self, client_id, client_secret, host, self._redirect_uri = redirect_uri self._use_duo_code_attribute = use_duo_code_attribute - # If duo_certs is None set it to the DEFAULT_CA_CERT_PATH - # so that we make sure we are pinning certs - if duo_certs is not None: + if disable_ca_pinning and duo_certs not in (None, DEFAULT_CA_CERT_PATH): + raise DuoException( + "Cannot both disable CA pinning and provide custom CA certificates" + ) + + self._disable_ca_pinning = disable_ca_pinning + + if disable_ca_pinning: + self._duo_certs = True + elif duo_certs is not None: if duo_certs == "DISABLE": self._duo_certs = False else: diff --git a/tests/test_setup_client.py b/tests/test_setup_client.py index 2902a91..03dd04f 100644 --- a/tests/test_setup_client.py +++ b/tests/test_setup_client.py @@ -1,3 +1,4 @@ +from unittest.mock import patch, MagicMock from duo_universal import client import unittest @@ -158,5 +159,99 @@ def test_proxy_set_off_kwargs(self): self.assertEqual(client_with_no_proxy._http_proxy, NONE) +class TestDisableCaPinning(unittest.TestCase): + + def test_default_is_pinning_enabled(self): + c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI) + self.assertFalse(c._disable_ca_pinning) + self.assertEqual(c._duo_certs, client.DEFAULT_CA_CERT_PATH) + + def test_disable_ca_pinning_true(self): + c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI, + disable_ca_pinning=True) + self.assertTrue(c._disable_ca_pinning) + self.assertTrue(c._duo_certs) + + def test_disable_ca_pinning_with_default_duo_certs(self): + c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI, + duo_certs=client.DEFAULT_CA_CERT_PATH, disable_ca_pinning=True) + self.assertTrue(c._disable_ca_pinning) + self.assertTrue(c._duo_certs) + + def test_disable_ca_pinning_with_none_duo_certs(self): + c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI, + duo_certs=None, disable_ca_pinning=True) + self.assertTrue(c._disable_ca_pinning) + self.assertTrue(c._duo_certs) + + def test_disable_ca_pinning_with_custom_duo_certs_raises(self): + with self.assertRaises(client.DuoException) as ctx: + client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI, + duo_certs=CA_CERT_NEW, disable_ca_pinning=True) + self.assertIn("Cannot both disable CA pinning", str(ctx.exception)) + + def test_disable_ca_pinning_with_disable_duo_certs_raises(self): + with self.assertRaises(client.DuoException) as ctx: + client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI, + duo_certs="DISABLE", disable_ca_pinning=True) + self.assertIn("Cannot both disable CA pinning", str(ctx.exception)) + + def test_disable_ca_pinning_false_preserves_existing_behavior(self): + c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI, + disable_ca_pinning=False) + self.assertEqual(c._duo_certs, client.DEFAULT_CA_CERT_PATH) + + +class TestDisableCaPinningRequests(unittest.TestCase): + + @patch('requests.post') + def test_health_check_pinning_disabled_uses_system_trust_store(self, requests_mock): + requests_mock.return_value = MagicMock(content=b'{"stat": "OK", "response": {"timestamp": 1}}') + c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI, + disable_ca_pinning=True) + c.health_check() + _, kwargs = requests_mock.call_args + self.assertTrue(kwargs['verify']) + self.assertIsNot(kwargs['verify'], client.DEFAULT_CA_CERT_PATH) + + @patch('requests.post') + def test_health_check_pinning_enabled_uses_bundled_certs(self, requests_mock): + requests_mock.return_value = MagicMock(content=b'{"stat": "OK", "response": {"timestamp": 1}}') + c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI) + c.health_check() + _, kwargs = requests_mock.call_args + self.assertEqual(kwargs['verify'], client.DEFAULT_CA_CERT_PATH) + + @patch('requests.post') + def test_token_exchange_pinning_disabled_uses_system_trust_store(self, requests_mock): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'id_token': 'fake'} + requests_mock.return_value = mock_response + c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI, + disable_ca_pinning=True) + try: + c.exchange_authorization_code_for_2fa_result('code', 'user') + except client.DuoException: + pass + _, kwargs = requests_mock.call_args + self.assertTrue(kwargs['verify']) + self.assertIsNot(kwargs['verify'], client.DEFAULT_CA_CERT_PATH) + + @patch('requests.post') + def test_token_exchange_pinning_enabled_uses_bundled_certs(self, requests_mock): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'id_token': 'fake'} + requests_mock.return_value = mock_response + c = client.Client(CLIENT_ID, CLIENT_SECRET, HOST, REDIRECT_URI) + try: + c.exchange_authorization_code_for_2fa_result('code', 'user') + except client.DuoException: + pass + _, kwargs = requests_mock.call_args + self.assertEqual(kwargs['verify'], client.DEFAULT_CA_CERT_PATH) + + if __name__ == '__main__': unittest.main()