Skip to content

Commit 9fe5396

Browse files
committed
fix(extension-huggingface): change to use axios to send huggingface api to prevent huggingface.js node version issue.
- remove "huggingface" and "node-fetch" library. - use "axios" to send API request to huggingface. - support passing "wait_for_model" and "use_cache" options when using "huggingface_table_question_answering" filter. - update test cases
1 parent d7956d8 commit 9fe5396

8 files changed

Lines changed: 147 additions & 200 deletions

File tree

packages/extension-huggingface/package.json

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,7 @@
2222
"url": "https://github.com/Canner/vulcan.git"
2323
},
2424
"license": "MIT",
25-
"dependencies": {
26-
"@huggingface/inference": "^2.6.1",
27-
"node-fetch": "2"
28-
},
2925
"peerDependencies": {
3026
"@vulcan-sql/core": "~0.6.0-0"
31-
},
32-
"devDependencies": {
33-
"@types/node-fetch": "^2.6.4"
3427
}
3528
}

packages/extension-huggingface/src/index.ts

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,3 @@
1-
/**
2-
*
3-
* References:
4-
* https://github.com/jestjs/jest/issues/2071#issuecomment-259709487
5-
* https://nodejs.org/en/blog/release/v16.15.0
6-
* https://stackoverflow.com/questions/48433783/referenceerror-fetch-is-not-defined
7-
*
8-
* In the Node.js 16, fetch API is an experimental function, but huggingface.js need to use fetch API, so import the "node-fetch" package and set it, This is a workaround.
9-
* */
10-
import fetch from 'node-fetch';
11-
globalThis.fetch = fetch as any;
12-
131
// Alias the Builder and Runner from the package and export them to prevent Extension loader loading the same Builder and Runner class when creating multiple Functional Filters or Tags.
142
import {
153
Builder as HuggingFaceTableQuestionAnsweringFilterBuilder,

packages/extension-huggingface/src/lib/tableQuestionAnswering.ts

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,56 +3,86 @@ import {
33
InternalError,
44
createFilterExtension,
55
} from '@vulcan-sql/core';
6-
7-
import {
8-
HuggingFaceOptions,
9-
createHuggingFaceInference,
10-
convertToHuggingFaceTable,
11-
} from './utils';
6+
import axios, { AxiosError } from 'axios';
7+
import { convertToHuggingFaceTable } from './utils';
128
import { isArray } from 'class-validator';
139
import { has } from 'lodash';
1410

11+
type HuggingFaceOptions = {
12+
accessToken: string;
13+
};
14+
15+
// More information described the options, see: https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task
1516
type TableQuestionAnsweringOptions = {
16-
model?: string;
1717
inputs: {
1818
query: string;
1919
table: Record<string, string[]>;
2020
};
21+
options: {
22+
use_cache: boolean;
23+
wait_for_model: boolean;
24+
};
25+
};
26+
27+
const request = async (url: string, data: any, token: string) => {
28+
try {
29+
const result = await axios.post(url, data, {
30+
headers: { Authorization: `Bearer ${token}` },
31+
});
32+
return result.data;
33+
} catch (error) {
34+
const axiosError = error as AxiosError;
35+
// https://axios-http.com/docs/handling_errors
36+
// if response has error, throw the response error, or throw the request error
37+
if (axiosError.response)
38+
throw new Error(JSON.stringify(axiosError.response?.data));
39+
throw new Error(axiosError.message);
40+
}
2141
};
2242

43+
// default recommended model, see https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task
44+
const getUrl = (model = 'google/tapas-base-finetuned-wtq') =>
45+
`https://api-inference.huggingface.co/models/${model}`;
46+
2347
export const TableQuestionAnsweringFilter: FunctionalFilter = async ({
2448
args,
2549
value,
2650
options,
2751
}) => {
28-
const hf = createHuggingFaceInference(options as HuggingFaceOptions);
52+
if (!options || !(options as HuggingFaceOptions).accessToken)
53+
throw new InternalError('please given access token');
2954

3055
if (!isArray(value))
3156
throw new InternalError('Input value must be an array of object');
3257

3358
if (!(typeof args === 'object') || !has(args, 'query'))
3459
throw new InternalError('Must provide "query" keyword argument');
35-
3660
if (!args['query'])
3761
throw new InternalError('The "query" argument must have value');
3862

63+
const token = (options as HuggingFaceOptions).accessToken;
3964
// Convert the data result format to table value format
4065
const table = convertToHuggingFaceTable(value);
4166
const context = {
42-
// default recommended model, see https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task
43-
model: 'google/tapas-base-finetuned-wtq',
4467
inputs: {
4568
query: args['query'],
4669
table,
4770
},
71+
options: {
72+
use_cache: args['use_cache'] ? args['use_cache'] : true,
73+
wait_for_model: args['wait_for_model'] ? args['wait_for_model'] : false,
74+
},
4875
} as TableQuestionAnsweringOptions;
49-
// if model is specified, assign the model provided
50-
if (args['model']) context.model = args['model'];
76+
77+
// Get table question answering url
78+
const url = args['model'] ? getUrl(args['model']) : getUrl();
5179

5280
try {
53-
const results = await hf.tableQuestionAnswering(context);
81+
const results = await request(url, context, token);
5482
// result format, convert to suitable FunctionalFilter response => https://huggingface.co/docs/api-inference/detailed_parameters#question-answering-task
55-
if (results.aggregator === 'NONE') return results.answer;
83+
if (!results.aggregator || results.aggregator === 'NONE')
84+
// trim the beginning & ending space if model returned answer exist the space, e.g: ' hello world'
85+
return (results.answer as string).trim();
5686
return results.cells.join(', ');
5787
} catch (error) {
5888
throw new InternalError(

packages/extension-huggingface/src/lib/utils/factory.ts

Lines changed: 0 additions & 15 deletions
This file was deleted.
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
1-
export * from './factory';
21
export * from './converter';

packages/extension-huggingface/test/factory.spec.ts

Lines changed: 0 additions & 16 deletions
This file was deleted.

packages/extension-huggingface/test/tableQuestionAnswering.spec.ts

Lines changed: 102 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import faker from '@faker-js/faker';
12
import { getTestCompiler } from '@vulcan-sql/test-utility';
23
import * as dotenv from 'dotenv';
34
import * as path from 'path';
@@ -32,53 +33,6 @@ const data = [
3233
},
3334
];
3435

35-
it.each([
36-
{ question: 'what repository has most stars?', expected: 'vulcan-sql' },
37-
{ question: 'what repository has lowest stars?', expected: 'hell-word' },
38-
{
39-
question: 'How many stars does the vulcan-sql repository have?',
40-
expected: '1000',
41-
},
42-
{
43-
question: 'How many stars does the accio repository have?',
44-
expected: '500',
45-
},
46-
{
47-
question: 'How many repositories related to data-lake topic?',
48-
expected: 'vulcan-sql, accio',
49-
},
50-
])(
51-
'Should get correct expected $answer when asking $question',
52-
async ({ question, expected }) => {
53-
// Arrange
54-
55-
const token = process.env['HF_ACCESS_TOKEN'];
56-
const { compileAndLoad, execute, getExecutedQueries, getCreatedBinding } =
57-
await getTestCompiler({
58-
extensions: { huggingface: path.join(__dirname, '..', 'src') },
59-
huggingface: {
60-
accessToken: token,
61-
},
62-
});
63-
64-
const sql = `{% set data = ${JSON.stringify(
65-
data
66-
)} %}SELECT {{ data | huggingface_table_question_answering(query="${question}") }}`;
67-
68-
// Act
69-
await compileAndLoad(sql);
70-
await execute({});
71-
72-
// Assert
73-
const queries = await getExecutedQueries();
74-
const bindings = await getCreatedBinding();
75-
76-
expect(queries[0]).toBe('SELECT $1');
77-
expect(bindings[0].get('$1')).toEqual(expected);
78-
},
79-
50 * 1000
80-
);
81-
8236
it(
8337
'Should throw error when not pass the "query" argument',
8438
async () => {
@@ -152,27 +106,121 @@ it('Should throw error when input value not be array of object', async () => {
152106
});
153107

154108
it(
155-
'Should throw error when provided model cause the Hugging Face tableQuestionAnswering task failed',
109+
'Should throw error when not provide access token',
156110
async () => {
157-
const token = process.env['HF_ACCESS_TOKEN'];
158111
const { compileAndLoad, execute } = await getTestCompiler({
159112
extensions: { huggingface: path.join(__dirname, '..', 'src') },
160113
huggingface: {
161-
accessToken: token,
114+
accessToken: '',
162115
},
163116
});
164117

165118
const sql = `{% set data = ${JSON.stringify(
166119
data
167-
)} %}SELECT {{ data | huggingface_table_question_answering(query="what repository has most stars?", model="neulab/omnitab-large-1024shot-finetuned-wtq-1024shot") }}`;
120+
)} %}SELECT {{ data | huggingface_table_question_answering("${faker.internet.password()}") }}`;
168121

169122
// Act
170123
await compileAndLoad(sql);
171124

172125
// Assert
173-
await expect(execute({})).rejects.toThrow(
174-
"Error when sending data to Hugging Face for executing TableQuestionAnswering tasks, details: Invalid inference output: Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}. Use the 'request' method with the same parameters to do a custom call with no type checking."
175-
);
126+
await expect(execute({})).rejects.toThrow('please given access token');
127+
},
128+
50 * 1000
129+
);
130+
131+
it(
132+
'Should throw error when not set hugging face options',
133+
async () => {
134+
const { compileAndLoad, execute } = await getTestCompiler({
135+
extensions: { huggingface: path.join(__dirname, '..', 'src') },
136+
});
137+
138+
const sql = `{% set data = ${JSON.stringify(
139+
data
140+
)} %}SELECT {{ data | huggingface_table_question_answering("${faker.internet.password()}") }}`;
141+
142+
// Act
143+
await compileAndLoad(sql);
144+
145+
// Assert
146+
await expect(execute({})).rejects.toThrow('please given access token');
147+
},
148+
50 * 1000
149+
);
150+
151+
it(
152+
'Should get correct expected value when provided "neulab/omnitab-large-1024shot-finetuned-wtq-1024shot" model and wait it for model',
153+
async () => {
154+
const token = process.env['HF_ACCESS_TOKEN'];
155+
const { compileAndLoad, execute, getExecutedQueries, getCreatedBinding } =
156+
await getTestCompiler({
157+
extensions: { huggingface: path.join(__dirname, '..', 'src') },
158+
huggingface: {
159+
accessToken: token,
160+
},
161+
});
162+
163+
const sql = `{% set data = ${JSON.stringify(
164+
data
165+
)} %}SELECT {{ data | huggingface_table_question_answering(query="what repository has most stars?", model="neulab/omnitab-large-1024shot-finetuned-wtq-1024shot", wait_for_model=true) }}`;
166+
167+
// Act
168+
await compileAndLoad(sql);
169+
await execute({});
170+
171+
// Assert
172+
const queries = await getExecutedQueries();
173+
const bindings = await getCreatedBinding();
174+
175+
expect(queries[0]).toBe('SELECT $1');
176+
expect(bindings[0].get('$1')).toEqual('vulcan-sql');
177+
},
178+
50 * 1000
179+
);
180+
181+
it.each([
182+
{ question: 'what repository has most stars?', expected: 'vulcan-sql' },
183+
{ question: 'what repository has lowest stars?', expected: 'hell-word' },
184+
{
185+
question: 'How many stars does the vulcan-sql repository have?',
186+
expected: '1000',
187+
},
188+
{
189+
question: 'How many stars does the accio repository have?',
190+
expected: '500',
191+
},
192+
{
193+
question: 'How many repositories related to data-lake topic?',
194+
expected: 'vulcan-sql, accio',
195+
},
196+
])(
197+
'Should get correct expected $answer when asking $question',
198+
async ({ question, expected }) => {
199+
// Arrange
200+
201+
const token = process.env['HF_ACCESS_TOKEN'];
202+
const { compileAndLoad, execute, getExecutedQueries, getCreatedBinding } =
203+
await getTestCompiler({
204+
extensions: { huggingface: path.join(__dirname, '..', 'src') },
205+
huggingface: {
206+
accessToken: token,
207+
},
208+
});
209+
210+
const sql = `{% set data = ${JSON.stringify(
211+
data
212+
)} %}SELECT {{ data | huggingface_table_question_answering(query="${question}", wait_for_model=true) }}`;
213+
214+
// Act
215+
await compileAndLoad(sql);
216+
await execute({});
217+
218+
// Assert
219+
const queries = await getExecutedQueries();
220+
const bindings = await getCreatedBinding();
221+
222+
expect(queries[0]).toBe('SELECT $1');
223+
expect(bindings[0].get('$1')).toEqual(expected);
176224
},
177225
50 * 1000
178226
);

0 commit comments

Comments
 (0)