Skip to content

Commit 75b8625

Browse files
committed
feat(extension-huggingface): refactor for reusing common feature and refactor huggingface filter logic
1 parent 180f8a6 commit 75b8625

6 files changed

Lines changed: 295 additions & 272 deletions

File tree

packages/extension-huggingface/src/lib/filters/tableQuestionAnswering.ts

Lines changed: 23 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,82 +3,56 @@ import {
33
InternalError,
44
createFilterExtension,
55
} from '@vulcan-sql/core';
6-
import axios, { AxiosError } from 'axios';
7-
import { convertToHuggingFaceTable } from '../utils';
8-
import { isArray } from 'class-validator';
9-
import { has } from 'lodash';
106

11-
type HuggingFaceOptions = {
12-
accessToken: string;
13-
};
7+
import { convertToHuggingFaceTable, postRequest } from '../utils';
8+
import { has, isArray, isEmpty, omit } from 'lodash';
9+
import {
10+
InferenceNLPOptions,
11+
HuggingFaceOptions,
12+
apiInferenceEndpoint,
13+
} from '../model';
1414

15-
// More information described the options, see: https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task
15+
// More information describing the options. See: https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task
1616
type TableQuestionAnsweringOptions = {
1717
inputs: {
1818
query: string;
1919
table: Record<string, string[]>;
2020
};
21-
options: {
22-
use_cache: boolean;
23-
wait_for_model: boolean;
24-
};
25-
};
26-
27-
const request = async (url: string, data: any, token: string) => {
28-
try {
29-
const result = await axios.post(url, data, {
30-
headers: { Authorization: `Bearer ${token}` },
31-
});
32-
return result.data;
33-
} catch (error) {
34-
const axiosError = error as AxiosError;
35-
// https://axios-http.com/docs/handling_errors
36-
// if response has error, throw the response error, or throw the request error
37-
if (axiosError.response)
38-
throw new Error(JSON.stringify(axiosError.response?.data));
39-
throw new Error(axiosError.message);
40-
}
21+
options?: InferenceNLPOptions;
4122
};
4223

43-
// default recommended model, see https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task
24+
/**
25+
* Get the table question answering URL. Uses the recommended model as the default value.
26+
* See: https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task
27+
* */
4428
const getUrl = (model = 'google/tapas-base-finetuned-wtq') =>
45-
`https://api-inference.huggingface.co/models/${model}`;
29+
`${apiInferenceEndpoint}/${model}`;
4630

4731
export const TableQuestionAnsweringFilter: FunctionalFilter = async ({
4832
args,
4933
value,
5034
options,
5135
}) => {
52-
if (!options || !(options as HuggingFaceOptions).accessToken)
53-
throw new InternalError('please given access token');
36+
const token = (options as HuggingFaceOptions)?.accessToken;
37+
if (!token) throw new InternalError('please given access token');
5438

5539
if (!isArray(value))
5640
throw new InternalError('Input value must be an array of object');
57-
5841
if (!(typeof args === 'object') || !has(args, 'query'))
5942
throw new InternalError('Must provide "query" keyword argument');
60-
if (!args['query'])
61-
throw new InternalError('The "query" argument must have value');
6243

63-
const token = (options as HuggingFaceOptions).accessToken;
6444
// Convert the data result format to table value format
6545
const table = convertToHuggingFaceTable(value);
66-
const context = {
67-
inputs: {
68-
query: args['query'],
69-
table,
70-
},
71-
options: {
72-
use_cache: args['use_cache'] ? args['use_cache'] : true,
73-
wait_for_model: args['wait_for_model'] ? args['wait_for_model'] : false,
74-
},
46+
// omit the hidden value '__keywords' from args; it is generated by nunjucks and is not related to HuggingFace.
47+
const { query, model, ...inferenceOptions } = omit(args, '__keywords');
48+
const payload = {
49+
inputs: { query, table },
7550
} as TableQuestionAnsweringOptions;
76-
77-
// Get table question answering url
78-
const url = args['model'] ? getUrl(args['model']) : getUrl();
51+
if (!isEmpty(inferenceOptions)) payload.options = inferenceOptions;
7952

8053
try {
81-
const results = await request(url, context, token);
54+
const url = getUrl(model);
55+
const results = await postRequest(url, payload, token);
8256
// convert to a JSON string so the user can get the whole result after parsing it in SQL
8357
return JSON.stringify(results);
8458
} catch (error) {
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
export type HuggingFaceOptions = {
2+
accessToken: string;
3+
};
4+
5+
export const apiInferenceEndpoint =
6+
'https://api-inference.huggingface.co/models';
7+
8+
// For more information, see: https://huggingface.co/docs/api-inference/detailed_parameters#natural-language-processing
9+
export type InferenceNLPOptions = {
10+
// Default: true. There is a cache layer on the inference API to speed up requests we have already seen. Most models can use those results as-is, as models are deterministic (meaning the results will be the same anyway).
11+
use_cache?: boolean;
12+
// Default: false. If the model is not ready, wait for it instead of receiving 503. It limits the number of requests required to get your inference done.
13+
wait_for_model?: boolean;
14+
};
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
export * from './converter';
2+
export * from './request';
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import axios, { AxiosError } from 'axios';
2+
3+
export const postRequest = async (url: string, data: any, token: string) => {
4+
try {
5+
const result = await axios.post(url, data, {
6+
headers: { Authorization: `Bearer ${token}` },
7+
});
8+
return result.data;
9+
} catch (error) {
10+
const axiosError = error as AxiosError;
11+
// https://axios-http.com/docs/handling_errors
12+
// if response has error, throw the response error, or throw the request error
13+
if (axiosError.response)
14+
throw new Error(JSON.stringify(axiosError.response?.data));
15+
throw new Error(axiosError.message);
16+
}
17+
};

0 commit comments

Comments
 (0)