Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/mistralai/client/_hooks/workflow_encoding_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@

logger = logging.getLogger(__name__)


def _strip_content_encoding_header(headers: httpx.Headers) -> httpx.Headers:
"""Remove Content-Encoding header since content is already decompressed by httpx."""
return httpx.Headers(
[(k, v) for k, v in headers.items() if k.lower() != "content-encoding"]
)


# Attribute name for storing config ID on SDKConfiguration
_ENCODING_CONFIG_ID_ATTR = "_workflow_encoding_config_id"

Expand Down Expand Up @@ -277,10 +285,9 @@ def _wrap_sse_response_with_decryption(
decrypting_stream = _DecryptingAsyncByteStream(original_stream, payload_encoder)

# Create new response with wrapped stream
# Use internal _content to avoid reading stream
new_response = httpx.Response(
status_code=response.status_code,
headers=response.headers,
headers=_strip_content_encoding_header(response.headers),
stream=decrypting_stream,
request=response.request,
extensions=response.extensions,
Expand Down Expand Up @@ -408,9 +415,7 @@ def after_success(
result = body.get("result")
if (
result is not None
and encoding_config.payload_encoder.check_is_payload_encoded(
result
)
and encoding_config.payload_encoder.check_is_payload_encoded(result)
):
decoded_result = _run_async(
encoding_config.payload_encoder.decode_network_result(result)
Expand All @@ -421,7 +426,7 @@ def after_success(

response = httpx.Response(
status_code=response.status_code,
headers=response.headers,
headers=_strip_content_encoding_header(response.headers),
content=new_content,
request=response.request,
extensions=response.extensions,
Expand All @@ -441,7 +446,7 @@ def after_success(

response = httpx.Response(
status_code=response.status_code,
headers=response.headers,
headers=_strip_content_encoding_header(response.headers),
content=new_content,
request=response.request,
extensions=response.extensions,
Expand Down
78 changes: 78 additions & 0 deletions src/mistralai/extra/tests/test_workflow_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,3 +682,81 @@ async def test_payload_encoder_encodes_event_content_without_offloading():

assert encoding_options == [EncodedPayloadOptions.COMPRESSED]
assert decoded == payload


@pytest.mark.asyncio
async def test_workflow_encoding_hook_handles_gzipped_response():
"""Test that WorkflowEncodingHook correctly handles gzipped responses.

When httpx receives a gzip-compressed response, it auto-decompresses the content
but preserves the Content-Encoding header. If we create a new Response with this
header but with already-decompressed content, httpx will try to decompress again,
causing a zlib error. The fix strips Content-Encoding when creating new Responses.
"""
import gzip
import httpx
from pydantic import SecretStr
from mistralai.client import Mistral
from mistralai.client._hooks.workflow_encoding_hook import (
WorkflowEncodingHook,
configure_workflow_encoding,
EXECUTE_WORKFLOW_OPERATION_ID,
)
from mistralai.client._hooks.types import AfterSuccessContext, HookContext

# Setup client with encryption
client = Mistral(api_key="test-key")
config = WorkflowEncodingConfig(
payload_encryption=PayloadEncryptionConfig(
mode=PayloadEncryptionMode.FULL,
main_key=SecretStr("0" * 64),
)
)
configure_workflow_encoding(
config,
namespace="test-namespace",
sdk_config=client.sdk_configuration,
)

# Create an encoded result using the encoder
encoder = PayloadEncoder(encoding_config=config)
context = WorkflowContext(namespace="test-namespace", execution_id="test-123")
original_data = {"secret": "value"}
encoded_input = await encoder.encode_network_input(original_data, context)

# Create gzipped response with encoded result
body = {
"execution_id": "test-exec-123",
"status": "COMPLETED",
"result": encoded_input.model_dump(mode="json"),
}
compressed_body = gzip.compress(json.dumps(body).encode("utf-8"))
mock_request = httpx.Request(
"GET", "https://api.example.com/v1/workflows/executions/test-exec-123"
)
response = httpx.Response(
status_code=200,
headers={"Content-Type": "application/json", "Content-Encoding": "gzip"},
content=compressed_body,
request=mock_request,
)

# Create hook context
hook_ctx = AfterSuccessContext(
HookContext(
config=client.sdk_configuration,
base_url="https://api.example.com",
operation_id=EXECUTE_WORKFLOW_OPERATION_ID,
oauth2_scopes=[],
security_source=None,
)
)

# Call after_success - without the fix, this raises httpx.DecodingError
hook = WorkflowEncodingHook()
result = hook.after_success(hook_ctx, response)

# Verify response is valid and result is decoded
assert isinstance(result, httpx.Response)
response_body = json.loads(result.content)
assert response_body["result"] == original_data
Loading