Description
For the `torch_aoti` platform, Triton gets the wrong number of inputs from AOTIModelPackageLoader: in the PyTorch 2.11 implementation it is always 2, the length of the returned vector `{input_spec, output_spec}`.
Triton Information
I'm using Triton Server from the docker image `26.03-py3`.
To Reproduce
I exported a PyTorch model (with PyTorch 2.11) whose graph signature looks like this:
```
Graph signature:
# inputs
p_op1_weight: PARAMETER target='op1.weight'
p_op1_bias: PARAMETER target='op1.bias'
p_op2_weight: PARAMETER target='op2.weight'
p_op2_bias: PARAMETER target='op2.bias'
p_op3_weight: PARAMETER target='op3.weight'
p_op3_bias: PARAMETER target='op3.bias'
x: USER_INPUT
# outputs
linear_1: USER_OUTPUT
```
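Of the seven entries in the signature, six are PARAMETERs baked into the package and only `x` is a USER_INPUT, so only one tensor is supplied at runtime. A torch-free sketch of that count, using the spec kinds printed above:

```python
# Input specs exactly as printed in the graph signature above: (name, kind).
input_specs = [
    ("p_op1_weight", "PARAMETER"),
    ("p_op1_bias", "PARAMETER"),
    ("p_op2_weight", "PARAMETER"),
    ("p_op2_bias", "PARAMETER"),
    ("p_op3_weight", "PARAMETER"),
    ("p_op3_bias", "PARAMETER"),
    ("x", "USER_INPUT"),
]

# Only USER_INPUT entries are fed at inference time; the parameters
# ship inside the .pt2 package and are not runtime inputs.
user_inputs = [name for name, kind in input_specs if kind == "USER_INPUT"]
print(len(user_inputs))  # 1
```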
And in the config file I put:
```
max_batch_size: 64
input [
  {
    name: "input__0"
    data_type: TYPE_FP32
    dims: [ 500 ]
    reshape { shape: [ 1, 500 ] }
  }
]
output [
  {
    name: "output__0"
    data_type: TYPE_FP32
    dims: [ -1, 128 ]
  }
]
```
And Triton Server failed with error:
```
E0422 04:21:04.650406 530 backend_model.cc:708] "ERROR: Failed to create instance: failed to create instance \"model1_0_0\" for inductor model \"model1\": Failed to create ModelInstanceState for model \"model1\": Failed to load model \"model1_0_0\" configuration expects 1 inputs, but model expects 2 inputs."
```
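My reading of the "2 inputs" in that message: the backend seems to take the length of the vector returned by AOTIModelPackageLoader (always `{input_spec, output_spec}`, i.e. 2) rather than the length of the input spec inside it. A hypothetical Python mirror of the mismatch (names are illustrative, not the actual backend code):

```python
# Hypothetical stand-in for what the loader hands back in PyTorch 2.11:
# a two-element vector {input_spec, output_spec}.
input_spec = ["input__0"]    # the one user input, matching config.pbtxt
output_spec = ["output__0"]
returned = [input_spec, output_spec]

model_expects = len(returned)     # 2 -> the number the error message reports
config_expects = len(input_spec)  # 1 -> the number config.pbtxt declares
print(model_expects, config_expects)  # 2 1
```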
More details on the pytorch model impl:
```python
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.op1 = torch.nn.LayerNorm([500,])
        self.op2 = torch.nn.Linear(500, 200)
        self.op3 = torch.nn.Linear(200, 128)

    def forward(self, x):
        x1 = self.op1(x)
        x2 = self.op2(x1)
        y = self.op3(x2)
        return y
```
Exported through:
```python
import os
import torch
from torch.export import Dim, ShapesCollection

if __name__ == '__main__':
    model1 = MyModel()
    model1 = model1.cuda()
    example_x = torch.randn(16, 500, device=torch.cuda.current_device())
    sc = ShapesCollection()
    sc[example_x] = {0: Dim("batch_size", min=1, max=64)}
    dynamic_shapes = sc.dynamic_shapes(model1, (example_x,))
    exported_model1 = torch.export.export(
        model1, (example_x,), dynamic_shapes=dynamic_shapes)
    torch._inductor.aoti_compile_and_package(
        exported_model1,
        package_path=os.path.join("./", "model1.pt2"),
        inductor_configs={},
    )
```
Expected behavior
I think the number of inputs should be 1 in this scenario. In general, it should match the number of inputs declared in the model's config file.