Skip to content

Commit 5513d50

Browse files
dstnluong and copybara-github
authored and committed
Publish timesfm serving docker code.
MG_DOCKER_CODES_PIPER_ORIGIN_REV_ID: 776632214
1 parent 85fa955 commit 5513d50

3 files changed

Lines changed: 588 additions & 0 deletions

File tree

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Serving image for TimesFM: CUDA 12.3.2 base image plus the Python packages
# needed to run the prediction server copied to /app/main.py.
FROM nvidia/cuda:12.3.2-devel-ubuntu22.04

# Install system libraries and the Python 3.10 toolchain in ONE layer so the
# apt package lists can be deleted in the same layer that downloaded them
# (deleting them in a later layer would not shrink the image).
RUN apt-get update && apt-get upgrade -y && apt-get install -y --no-install-recommends \
        cmake \
        curl \
        wget \
        sudo \
        gnupg \
        libsm6 \
        libxext6 \
        libxrender-dev \
        lsb-release \
        ca-certificates \
        build-essential \
        git \
        software-properties-common \
        cuda-toolkit \
        libcudnn8 \
        apt-transport-https \
        python3.10 \
        python3.10-venv \
        python3.10-dev \
        python3-pip \
    && apt-get autoremove -y \
    && rm -rf /var/lib/apt/lists/*

# Install pinned Python dependencies. Requirement specifiers with extras are
# quoted so the shell never treats the square brackets as a glob pattern, and
# --no-cache-dir keeps pip's download cache out of the image.
RUN pip install --no-cache-dir --upgrade pip
RUN pip install --no-cache-dir --upgrade --ignore-installed \
        "jax[cuda12]==0.4.26" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
        numpy==1.26.4 \
        paxml==1.4.0 \
        praxis==1.4.0 \
        jaxlib==0.4.26 \
        pandas==2.1.4 \
        einshape==1.0.0 \
        utilsforecast==0.1.10 \
        "huggingface_hub[cli]==0.23.0" \
        "google-cloud-aiplatform[prediction]==1.51.0" \
        fastapi==0.109.1 \
        flask==3.0.3 \
        "smart_open[gcs]==7.0.4" \
        protobuf==3.19.6 \
        scikit-learn==1.0.2 \
        timesfm==1.0.1

# Download license (saved into the current working directory).
RUN wget https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/LICENSE

# Copy the serving scaffold into /app.
COPY model_oss/timesfm/main.py /app/main.py
COPY model_oss/timesfm/predictor.py /app/predictor.py

# NOTE(review): relative to the working directory inherited from the base
# image, which is not visible here - confirm the intended directory.
WORKDIR ..

# Start the inference server.
CMD ["python3", "/app/main.py"]
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Predict server for TimesFM."""
2+
3+
import json
4+
import os
5+
import flask
6+
import predictor
7+
8+
# Bind address and the HTTP status codes returned by the route handlers.
_HOST = '0.0.0.0'
_OK_STATUS = 200
_INTERNAL_ERROR_STATUS = 500

# Flask application object that the route decorators register against.
app = flask.Flask(__name__)

# Build the TimesFM predictor and load its checkpoints from the location named
# by the AIP_STORAGE_URI environment variable (a missing variable raises
# KeyError at import time). Note that this rebinds the name `predictor` from
# the imported module to the predictor instance.
predictor = predictor.TimesFMPredictor()
predictor.load(os.environ['AIP_STORAGE_URI'])
17+
18+
19+
@app.route(os.environ['AIP_HEALTH_ROUTE'], methods=['GET'])
def health() -> flask.Response:
  """Answers health-check requests.

  Returns:
    An empty `flask.Response` with the OK status code.
  """
  ok_response = flask.Response(status=_OK_STATUS)
  return ok_response
22+
23+
24+
@app.route(os.environ['AIP_PREDICT_ROUTE'], methods=['GET', 'POST'])
def predict() -> flask.Response:
  """Calls TimesFM for prediction.

  The request body is parsed as JSON regardless of its content type
  (`force=True`); a body that cannot be parsed yields `None`
  (`silent=True`), which is handed to the predictor unchanged. The body then
  goes through the predictor's preprocess, predict and postprocess steps.

  Returns:
    A `flask.Response` containing the prediction result in JSON with the OK
    status, or a JSON object `{"error": <message>}` with the internal-error
    status if any step raises.
  """
  try:
    body = flask.request.get_json(silent=True, force=True)
    preprocessed_inputs = predictor.preprocess(body)
    outputs = predictor.predict(preprocessed_inputs)
    postprocessed_outputs = predictor.postprocess(outputs)
    payload = json.dumps(postprocessed_outputs)
    status = _OK_STATUS
  except Exception as e:  # pylint: disable=broad-exception-caught
    # Record the full traceback server-side; the client only receives the
    # exception message, so without this the failure cause would be lost.
    app.logger.exception('TimesFM prediction request failed.')
    payload = json.dumps({'error': str(e)})
    status = _INTERNAL_ERROR_STATUS
  return flask.Response(payload, status=status, mimetype='application/json')
47+
48+
49+
if __name__ == '__main__':
  # Serve on all interfaces at the port named by the AIP_HTTP_PORT
  # environment variable (passed through as the string read from it).
  app.run(port=os.environ['AIP_HTTP_PORT'], host=_HOST)

0 commit comments

Comments
 (0)