Skip to content

Commit 5e22e51

Browse files
committed
chore(extension-huggingface): add logic to check that "query" has a value, with test cases for the huggingface filters
- fix grammar in README. - fix the section of the document. - add logic for checking that query has a value, with test cases
1 parent e192990 commit 5e22e51

7 files changed

Lines changed: 139 additions & 26 deletions

File tree

packages/doc/docs/extensions/huggingface/huggingface-text-generation.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ The [Text Generation](https://huggingface.co/docs/api-inference/detailed_paramet
66

77
The result will be a string from `huggingface_text_generation`.
88

9-
:::📢 Notice
10-
The **Text Generation** default model is **gpt2**, If you would like to use the [Meta LLama2](https://huggingface.co/meta-llama) models, you have two method to do:
9+
:::info
10+
The **Text Generation** default model is **gpt2**, If you would like to use the [Meta LLama2](https://huggingface.co/meta-llama) models, you have two methods to do:
1111

1212
1. Subscribe to the [Pro Account](https://huggingface.co/pricing#pro).
1313
- Set the Meta LLama2 model using the `model` keyword argument in `huggingface_text_generation`, e.g: `meta-llama/Llama-2-13b-chat-hf`.

packages/extension-huggingface/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ The [Text Generation](https://huggingface.co/docs/api-inference/detailed_paramet
112112

113113
Using the `huggingface_text_generation` filter. The result will be a string from `huggingface_text_generation`.
114114

115-
**📢 Notice**: The **Text Generation** default model is **gpt2**, If you would like to use the [Meta LLama2](https://huggingface.co/meta-llama) models, you have two method to do:
115+
**📢 Notice**: The **Text Generation** default model is **gpt2**, If you would like to use the [Meta LLama2](https://huggingface.co/meta-llama) models, you have two methods to do:
116116

117117
1. Subscribe to the [Pro Account](https://huggingface.co/pricing#pro).
118118
- Set the Meta LLama2 model using the `model` keyword argument in `huggingface_text_generation`, e.g: `meta-llama/Llama-2-13b-chat-hf`.

packages/extension-huggingface/src/index.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ import {
55
} from './lib/filters/tableQuestionAnswering';
66

77
import {
8-
Builder as TextGenerationFilterBuilder,
9-
Runner as TextGenerationFilterRunner,
8+
Builder as HuggingFaceTextGenerationFilterBuilder,
9+
Runner as HuggingFaceTextGenerationFilterRunner,
1010
} from './lib/filters/textGeneration';
1111

1212
export default [
1313
HuggingFaceTableQuestionAnsweringFilterBuilder,
1414
HuggingFaceTableQuestionAnsweringFilterRunner,
15-
TextGenerationFilterBuilder,
16-
TextGenerationFilterRunner,
15+
HuggingFaceTextGenerationFilterBuilder,
16+
HuggingFaceTextGenerationFilterRunner,
1717
];

packages/extension-huggingface/src/lib/filters/tableQuestionAnswering.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ export const TableQuestionAnsweringFilter: FunctionalFilter = async ({
4040
throw new InternalError('Input value must be an array of object');
4141
if (!(typeof args === 'object') || !has(args, 'query'))
4242
throw new InternalError('Must provide "query" keyword argument');
43+
if (!args['query'])
44+
throw new InternalError('The "query" argument must have value');
4345

4446
// Convert the data result format to table value format
4547
const table = convertToHuggingFaceTable(value);

packages/extension-huggingface/src/lib/filters/textGeneration.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ export const TextGenerationFilter: FunctionalFilter = async ({
5555
throw new InternalError('Input value must be an array of object');
5656
if (!(typeof args === 'object') || !has(args, 'query'))
5757
throw new InternalError('Must provide "query" keyword argument');
58+
if (!args['query'])
59+
throw new InternalError('The "query" argument must have value');
5860

5961
// Convert the data result to JSON string as question context
6062
const context = JSON.stringify(value);

packages/extension-huggingface/test/tableQuestionAnswering.spec.ts

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,58 @@ describe('Test "huggingface_table_question_answering" filter', () => {
3434
50 * 1000
3535
);
3636

37+
it(
38+
'Should throw error when pass the "query" argument but value is undefined',
39+
async () => {
40+
const token = process.env['HF_ACCESS_TOKEN'];
41+
const { compileAndLoad, execute } = await getTestCompiler({
42+
extensions: { huggingface: path.join(__dirname, '..', 'src') },
43+
huggingface: {
44+
accessToken: token,
45+
},
46+
});
47+
48+
const sql = `{% set data = ${JSON.stringify(
49+
repositories
50+
)} %}SELECT {{ data | huggingface_table_question_answering(query=undefined) }}`;
51+
52+
// Act
53+
await compileAndLoad(sql);
54+
55+
// Assert
56+
await expect(execute({})).rejects.toThrow(
57+
'The "query" argument must have value'
58+
);
59+
},
60+
50 * 1000
61+
);
62+
63+
it(
64+
'Should throw error when pass the "query" argument but value is empty string',
65+
async () => {
66+
const token = process.env['HF_ACCESS_TOKEN'];
67+
const { compileAndLoad, execute } = await getTestCompiler({
68+
extensions: { huggingface: path.join(__dirname, '..', 'src') },
69+
huggingface: {
70+
accessToken: token,
71+
},
72+
});
73+
74+
const sql = `{% set data = ${JSON.stringify(
75+
repositories
76+
)} %}SELECT {{ data | huggingface_table_question_answering(query='') }}`;
77+
78+
// Act
79+
await compileAndLoad(sql);
80+
81+
// Assert
82+
await expect(execute({})).rejects.toThrow(
83+
'The "query" argument must have value'
84+
);
85+
},
86+
50 * 1000
87+
);
88+
3789
it('Should throw error when input value not be array of object', async () => {
3890
const token = process.env['HF_ACCESS_TOKEN'];
3991
const { compileAndLoad, execute } = await getTestCompiler({

packages/extension-huggingface/test/textGeneration.spec.ts

Lines changed: 76 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,58 @@ describe('Test "huggingface_text_generation" filter', () => {
3333
50 * 1000
3434
);
3535

36+
it(
37+
'Should throw error when pass the "query" argument but value is undefined',
38+
async () => {
39+
const token = process.env['HF_ACCESS_TOKEN'];
40+
const { compileAndLoad, execute } = await getTestCompiler({
41+
extensions: { huggingface: path.join(__dirname, '..', 'src') },
42+
huggingface: {
43+
accessToken: token,
44+
},
45+
});
46+
47+
const sql = `{% set data = ${JSON.stringify(
48+
repositories
49+
)} %}SELECT {{ data | huggingface_text_generation(query=undefined) }}`;
50+
51+
// Act
52+
await compileAndLoad(sql);
53+
54+
// Assert
55+
await expect(execute({})).rejects.toThrow(
56+
'The "query" argument must have value'
57+
);
58+
},
59+
50 * 1000
60+
);
61+
62+
it(
63+
'Should throw error when pass the "query" argument but value is empty string',
64+
async () => {
65+
const token = process.env['HF_ACCESS_TOKEN'];
66+
const { compileAndLoad, execute } = await getTestCompiler({
67+
extensions: { huggingface: path.join(__dirname, '..', 'src') },
68+
huggingface: {
69+
accessToken: token,
70+
},
71+
});
72+
73+
const sql = `{% set data = ${JSON.stringify(
74+
repositories
75+
)} %}SELECT {{ data | huggingface_text_generation(query='') }}`;
76+
77+
// Act
78+
await compileAndLoad(sql);
79+
80+
// Assert
81+
await expect(execute({})).rejects.toThrow(
82+
'The "query" argument must have value'
83+
);
84+
},
85+
50 * 1000
86+
);
87+
3688
it('Should throw error when input value not be array of object', async () => {
3789
const token = process.env['HF_ACCESS_TOKEN'];
3890
const { compileAndLoad, execute } = await getTestCompiler({
@@ -75,27 +127,27 @@ describe('Test "huggingface_text_generation" filter', () => {
75127
},
76128
50 * 1000
77129
);
78-
130+
79131
it(
80-
'Should not throw when pass the "query" argument by dynamic parameter through HuggingFace recommended model',
132+
'Should not throw when passing the "query" argument by dynamic parameter through HuggingFace default recommended "gpt2" model',
81133
async () => {
82134
const token = process.env['HF_ACCESS_TOKEN'];
83-
const { compileAndLoad, execute } =
84-
await getTestCompiler({
85-
extensions: { huggingface: path.join(__dirname, '..', 'src') },
86-
huggingface: {
87-
accessToken: token,
88-
},
89-
});
135+
const { compileAndLoad, execute } = await getTestCompiler({
136+
extensions: { huggingface: path.join(__dirname, '..', 'src') },
137+
huggingface: {
138+
accessToken: token,
139+
},
140+
});
90141

91142
const sql = `{% set data = ${JSON.stringify(
92143
repositories
93144
)} %}SELECT {{ data | huggingface_text_generation(query=context.params.value, wait_for_model=true, use_cache=false) }}`;
94145

95-
96146
await compileAndLoad(sql);
97147
// Assert
98-
await expect(execute({ value: 'what repository has most stars?' })).resolves.not.toThrow();
148+
await expect(
149+
execute({ value: 'what repository has most stars?' })
150+
).resolves.not.toThrow();
99151
},
100152
100 * 1000
101153
);
@@ -117,16 +169,17 @@ describe('Test "huggingface_text_generation" filter', () => {
117169
repositories
118170
)} %}SELECT {{ data | huggingface_text_generation(query=context.params.value,model="meta-llama/Llama-2-13b-chat-hf", wait_for_model=true, use_cache=false) }}`;
119171

120-
121172
await compileAndLoad(sql);
122173
await execute({ value: 'what repository has most stars?' });
123174

124175
// Assert
125176
const queries = await getExecutedQueries();
126177
const bindings = await getCreatedBinding();
127-
178+
128179
expect(queries[0]).toBe('SELECT $1');
129-
expect(bindings[0].get('$1')).toEqual('Answer: Based on the information provided, the repository with the most stars is "vulcan-sql" with 1000 stars.');
180+
expect(bindings[0].get('$1')).toEqual(
181+
'Answer: Based on the information provided, the repository with the most stars is "vulcan-sql" with 1000 stars.'
182+
);
130183
},
131184
100 * 1000
132185
);
@@ -135,19 +188,23 @@ describe('Test "huggingface_text_generation" filter', () => {
135188
it.skip.each([
136189
{
137190
question: 'what repository has most stars?',
138-
expected: 'Answer: Based on the information provided, the repository with the most stars is "vulcan-sql" with 1000 stars.'
191+
expected:
192+
'Answer: Based on the information provided, the repository with the most stars is "vulcan-sql" with 1000 stars.',
139193
},
140194
{
141195
question: 'what repository has lowest stars?',
142-
expected: 'Answer: Based on the information provided, the repository with the lowest stars is "hello-world" with 0 stars.'
196+
expected:
197+
'Answer: Based on the information provided, the repository with the lowest stars is "hello-world" with 0 stars.',
143198
},
144199
{
145200
question: 'How many stars does the vulcan-sql repository have?',
146-
expected: 'Answer: Based on the information provided, the vulcan-sql repository has 1000 stars.'
201+
expected:
202+
'Answer: Based on the information provided, the vulcan-sql repository has 1000 stars.',
147203
},
148204
{
149205
question: 'How many stars does the accio repository have?',
150-
expected: 'Answer: Based on the information provided, the accio repository has 500 stars.'
206+
expected:
207+
'Answer: Based on the information provided, the accio repository has 500 stars.',
151208
},
152209
{
153210
question: 'How many repositories related to data-lake topic?',
@@ -156,7 +213,7 @@ describe('Test "huggingface_text_generation" filter', () => {
156213
1. vulcan-sql
157214
2. accio
158215
159-
Both of these repositories have the data-lake topic in their description.`
216+
Both of these repositories have the data-lake topic in their description.`,
160217
},
161218
])(
162219
'Should get "$expected" answer when asking "$question" through "meta-llama/Llama-2-13b-chat-hf" model',

0 commit comments

Comments
 (0)