diff --git a/src/mistralai/client/_hooks/workflow_encoding_hook.py b/src/mistralai/client/_hooks/workflow_encoding_hook.py index d65c3ff4..b33598f5 100644 --- a/src/mistralai/client/_hooks/workflow_encoding_hook.py +++ b/src/mistralai/client/_hooks/workflow_encoding_hook.py @@ -24,6 +24,14 @@ logger = logging.getLogger(__name__) + +def _strip_content_encoding_header(headers: httpx.Headers) -> httpx.Headers: + """Remove Content-Encoding header since content is already decompressed by httpx.""" + return httpx.Headers( + [(k, v) for k, v in headers.items() if k.lower() != "content-encoding"] + ) + + # Attribute name for storing config ID on SDKConfiguration _ENCODING_CONFIG_ID_ATTR = "_workflow_encoding_config_id" @@ -277,10 +285,9 @@ def _wrap_sse_response_with_decryption( decrypting_stream = _DecryptingAsyncByteStream(original_stream, payload_encoder) # Create new response with wrapped stream - # Use internal _content to avoid reading stream new_response = httpx.Response( status_code=response.status_code, - headers=response.headers, + headers=_strip_content_encoding_header(response.headers), stream=decrypting_stream, request=response.request, extensions=response.extensions, @@ -408,9 +415,7 @@ def after_success( result = body.get("result") if ( result is not None - and encoding_config.payload_encoder.check_is_payload_encoded( - result - ) + and encoding_config.payload_encoder.check_is_payload_encoded(result) ): decoded_result = _run_async( encoding_config.payload_encoder.decode_network_result(result) @@ -421,7 +426,7 @@ def after_success( response = httpx.Response( status_code=response.status_code, - headers=response.headers, + headers=_strip_content_encoding_header(response.headers), content=new_content, request=response.request, extensions=response.extensions, @@ -441,7 +446,7 @@ def after_success( response = httpx.Response( status_code=response.status_code, - headers=response.headers, + headers=_strip_content_encoding_header(response.headers), content=new_content, request=response.request, extensions=response.extensions, diff --git a/src/mistralai/extra/tests/test_workflow_encoding.py b/src/mistralai/extra/tests/test_workflow_encoding.py index 8aa4f2b1..ad7ff192 100644 --- a/src/mistralai/extra/tests/test_workflow_encoding.py +++ b/src/mistralai/extra/tests/test_workflow_encoding.py @@ -682,3 +682,81 @@ async def test_payload_encoder_encodes_event_content_without_offloading(): assert encoding_options == [EncodedPayloadOptions.COMPRESSED] assert decoded == payload + + +@pytest.mark.asyncio +async def test_workflow_encoding_hook_handles_gzipped_response(): + """Test that WorkflowEncodingHook correctly handles gzipped responses. + + When httpx receives a gzip-compressed response, it auto-decompresses the content + but preserves the Content-Encoding header. If we create a new Response with this + header but with already-decompressed content, httpx will try to decompress again, + causing a zlib error. The fix strips Content-Encoding when creating new Responses. + """ + import gzip + import httpx + from pydantic import SecretStr + from mistralai.client import Mistral + from mistralai.client._hooks.workflow_encoding_hook import ( + WorkflowEncodingHook, + configure_workflow_encoding, + EXECUTE_WORKFLOW_OPERATION_ID, + ) + from mistralai.client._hooks.types import AfterSuccessContext, HookContext + + # Setup client with encryption + client = Mistral(api_key="test-key") + config = WorkflowEncodingConfig( + payload_encryption=PayloadEncryptionConfig( + mode=PayloadEncryptionMode.FULL, + main_key=SecretStr("0" * 64), + ) + ) + configure_workflow_encoding( + config, + namespace="test-namespace", + sdk_config=client.sdk_configuration, + ) + + # Create an encoded result using the encoder + encoder = PayloadEncoder(encoding_config=config) + context = WorkflowContext(namespace="test-namespace", execution_id="test-123") + original_data = {"secret": "value"} + encoded_input = await encoder.encode_network_input(original_data, context) + + # Create gzipped response with encoded result + body = { + "execution_id": "test-exec-123", + "status": "COMPLETED", + "result": encoded_input.model_dump(mode="json"), + } + compressed_body = gzip.compress(json.dumps(body).encode("utf-8")) + mock_request = httpx.Request( + "GET", "https://api.example.com/v1/workflows/executions/test-exec-123" + ) + response = httpx.Response( + status_code=200, + headers={"Content-Type": "application/json", "Content-Encoding": "gzip"}, + content=compressed_body, + request=mock_request, + ) + + # Create hook context + hook_ctx = AfterSuccessContext( + HookContext( + config=client.sdk_configuration, + base_url="https://api.example.com", + operation_id=EXECUTE_WORKFLOW_OPERATION_ID, + oauth2_scopes=[], + security_source=None, + ) + ) + + # Call after_success - without the fix, this raises httpx.DecodingError + hook = WorkflowEncodingHook() + result = hook.after_success(hook_ctx, response) + + # Verify response is valid and result is decoded + assert isinstance(result, httpx.Response) + response_body = json.loads(result.content) + assert response_body["result"] == original_data