Skip to content

Commit 48c37cb

Browse files
committed
feat(extension-huggingface): support text generation task
- Add the "TextGenerationFilter". - Support passing an "endpoint" keyword argument to the HuggingFace filters when using different filter tasks. - Add test cases for "TextGenerationFilter".
1 parent 75b8625 commit 48c37cb

5 files changed

Lines changed: 354 additions & 3 deletions

File tree

packages/extension-huggingface/README.md

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ VulcanSQL support using Hugging Face tasks by [VulcanSQL Filters](https://vulcan
2727
2828
**⚠️ Caution**: Hugging Face has a [rate limit](https://huggingface.co/docs/api-inference/faq#rate-limits), so it does not allow sending large datasets to the Hugging Face library for processing. Otherwise, using a different Hugging Face model may yield different results or even result in failure.
2929
30+
3031
### Table Question Answering
3132
3233
The [Table Question Answering](https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task) is one of the Natural Language Processing tasks supported by Hugging Face.
@@ -50,7 +51,7 @@ The result will be converted to a JSON string from `huggingface_table_question_a
5051
"description": "Query Your Data Warehouse Like Exploring One Big View."
5152
},
5253
{
53-
"repository": "hell-word",
54+
"repository": "hello-world",
5455
"topic": [],
5556
"description": "Sample repository for testing"
5657
}
@@ -80,6 +81,7 @@ SELECT {{ data | huggingface_table_question_answering(query="How many repositori
8081
{% set question = "List display name where gender are female?" %}
8182
8283
-- The "model" keyword argument is optional. If not provided, the default value is 'google/tapas-base-finetuned-wtq'.
84+
-- The "endpoint" keyword argument is optional. If was provided, use the endpoint instead of the huggingface official URL with provided model if you also gave.
8385
-- The "wait_for_model" keyword argument is optional. If not provided, the default value is false.
8486
-- The "use_cache" keyword argument is optional. If not provided, the default value is true.
8587
SELECT {{ products.value() | huggingface_table_question_answering(query=question, model="microsoft/tapex-base-finetuned-wtq", wait_for_model=true, use_cache=true) }}
@@ -94,3 +96,60 @@ SELECT {{ products.value() | huggingface_table_question_answering(query=question
9496
}
9597
]
9698
```
99+
100+
### Text Generation
101+
102+
The [Text Generation](https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task) is one of the Natural Language Processing tasks supported by Hugging Face.
103+
104+
Use the `huggingface_text_generation` filter. The result of `huggingface_text_generation` will be a string.
105+
106+
**Notice**: The default model for **Text Generation** is **gpt2**. If you would like to use the [Meta LLama2](https://huggingface.co/meta-llama) models, there are two ways to do so:
107+
108+
1. Subscribe to the [Pro Account](https://huggingface.co/pricing#pro). Set the Meta LLama2 model using the `model` keyword argument in `huggingface_text_generation`, e.g: `meta-llama/Llama-2-13b-chat-hf`.
109+
110+
**SQL**:
111+
112+
```sql
113+
{% set data = [
114+
{
115+
"rank": 1,
116+
"institution": "Massachusetts Institute of Technology (MIT)",
117+
"location code":"US",
118+
"location":"United States"
119+
},
120+
{
121+
"rank": 2,
122+
"institution": "University of Cambridge",
123+
"location code":"UK",
124+
"location":"United Kingdom"
125+
},
126+
{
127+
"rank": 3,
128+
"institution": "Stanford University"
129+
"location code":"US",
130+
"location":"United States"
131+
}
132+
-- other universities.....
133+
] %}
134+
135+
SELECT {{ data | huggingface_text_generation(query="Which university is the top-ranked university?", model="meta-llama/Llama-2-13b-chat-hf") }} as result
136+
```
137+
138+
**Response**:
139+
140+
```json
141+
[
142+
{
143+
"result": "Answer: Based on the provided list, the top-ranked university is Massachusetts Institute of Technology (MIT) with a rank of 1."
144+
}
145+
]
146+
```
147+
148+
2. Select one of the [Meta LLama2](https://huggingface.co/meta-llama) Models and deploy it to the [Inference Endpoint](https://huggingface.co/inference-endpoints). Set the endpoint URL using the `endpoint` keyword argument in `huggingface_text_generation`.
149+
150+
```sql
151+
SELECT {{ data | huggingface_text_generation(query="Which university is the top-ranked university?", endpoint='xxx.yyy.zzz.huggingface.cloud') }} as result
152+
```
153+
154+
155+

packages/extension-huggingface/src/index.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@ import {
44
Runner as HuggingFaceTableQuestionAnsweringFilterRunner,
55
} from './lib/filters/tableQuestionAnswering';
66

7+
import {
8+
Builder as TextGenerationFilterBuilder,
9+
Runner as TextGenerationFilterRunner,
10+
} from './lib/filters/textGeneration';
11+
712
export default [
813
HuggingFaceTableQuestionAnsweringFilterBuilder,
914
HuggingFaceTableQuestionAnsweringFilterRunner,
15+
TextGenerationFilterBuilder,
16+
TextGenerationFilterRunner,
1017
];

packages/extension-huggingface/src/lib/filters/tableQuestionAnswering.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ export const TableQuestionAnsweringFilter: FunctionalFilter = async ({
4444
// Convert the data result format to table value format
4545
const table = convertToHuggingFaceTable(value);
4646
// omit hidden value '__keywords' from args, it generated from nunjucks and not related to HuggingFace.
47-
const { query, model, ...inferenceOptions } = omit(args, '__keywords');
47+
const { query, model, endpoint, ...inferenceOptions } = omit(args, '__keywords');
4848
const payload = {
4949
inputs: { query, table },
5050
} as TableQuestionAnsweringOptions;
5151
if (!isEmpty(inferenceOptions)) payload.options = inferenceOptions;
5252

5353
try {
54-
const url = getUrl(model);
54+
const url = endpoint ? endpoint : getUrl(model);
5555
const results = await postRequest(url, payload, token);
5656
// convert to JSON string to make user get the whole result after parsing it in SQL
5757
return JSON.stringify(results);
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import {
2+
FunctionalFilter,
3+
InternalError,
4+
createFilterExtension,
5+
} from '@vulcan-sql/core';
6+
import { has, isArray, isEmpty, omit, pick } from 'lodash';
7+
import {
8+
HuggingFaceOptions,
9+
InferenceNLPOptions,
10+
apiInferenceEndpoint,
11+
} from '../model';
12+
import { postRequest } from '../utils';
13+
14+
// More information described the options. See: https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
15+
type TextGenerationOptions = {
16+
inputs: string;
17+
parameters?: {
18+
// Integer to define the top tokens considered within the sample operation to create new text.
19+
top_k?: number;
20+
// Float to define the tokens that are within the sample operation of text generation. Add tokens in the sample for more probable to least probable until the sum of the probabilities is greater than top_p.
21+
top_p?: number;
22+
// Default: 0.1. Range: (0.0 - 100.0). The temperature of the sampling operation. 1 means regular sampling, 0 means always take the highest score, 100.0 is getting closer to uniform probability.
23+
temperature?: number;
24+
// Range: (0.0 - 100.0). The more a token is used within generation the more it is penalized to not be picked in successive generation passes.
25+
repetition_penalty?: number;
26+
// Default: 250. The amount of new tokens to be generated, this does not include the input length it is a estimate of the size of generated text you want. Each new tokens slows down the request, so look for balance between response times and length of text generated.
27+
max_new_tokens?: number;
28+
// Range (0-120.0). The amount of time in seconds that the query should take maximum. Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens for best results.
29+
max_time?: number;
30+
// Default: false. If set to False, the return results will not contain the original query making it easier for prompting.
31+
return_full_text?: boolean;
32+
// Default: 1. The number of proposition you want to be returned.
33+
num_return_sequences?: number;
34+
// Whether or not to use sampling, use greedy decoding otherwise.
35+
do_sample?: boolean;
36+
};
37+
options?: InferenceNLPOptions;
38+
};
39+
40+
/**
41+
* Get text generation url. Used meta Llama 2.0 model be default value.
42+
* See: https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
43+
* */
44+
const getUrl = (model = 'gpt2') => `${apiInferenceEndpoint}/${model}`;
45+
46+
export const TextGenerationFilter: FunctionalFilter = async ({
47+
args,
48+
value,
49+
options,
50+
}) => {
51+
const token = (options as HuggingFaceOptions)?.accessToken;
52+
if (!token) throw new InternalError('please given access token');
53+
54+
if (!isArray(value))
55+
throw new InternalError('Input value must be an array of object');
56+
if (!(typeof args === 'object') || !has(args, 'query'))
57+
throw new InternalError('Must provide "query" keyword argument');
58+
59+
// Convert the data result to JSON string as question context
60+
const context = JSON.stringify(value);
61+
// omit hidden value '__keywords' from args, it generated from nunjucks and not related to HuggingFace.
62+
const { query, model, endpoint, ...otherArgs } = omit(args, '__keywords');
63+
const inferenceOptions = pick(otherArgs, ['use_cache', 'wait_for_model']);
64+
const parameters = omit(otherArgs, ['use_cache', 'wait_for_model', 'endpoint']);
65+
const payload = {
66+
inputs: `Context: ${context}. Question: ${query}}`,
67+
parameters: {
68+
return_full_text: false,
69+
max_new_tokens: 250,
70+
temperature: 0.1,
71+
}
72+
} as TextGenerationOptions;
73+
if (!isEmpty(parameters)) payload.parameters = parameters;
74+
if (!isEmpty(inferenceOptions)) payload.options = inferenceOptions;
75+
76+
try {
77+
// if not given endpoint, use default HuggingFace inference endpoint
78+
const url = endpoint ? endpoint : getUrl(model);
79+
const results = await postRequest(url, payload, token);
80+
// get the "generated_text" field, and trim leading and trailing white space.
81+
return String(results[0]['generated_text']).trim();
82+
} catch (error) {
83+
throw new InternalError(
84+
`Error when sending data to Hugging Face for executing TextGeneration tasks, details: ${
85+
(error as Error).message
86+
}`
87+
);
88+
}
89+
};
90+
91+
export const [Builder, Runner] = createFilterExtension(
92+
'huggingface_text_generation',
93+
TextGenerationFilter
94+
);

0 commit comments

Comments
 (0)