|
| 1 | +import { |
| 2 | + FunctionalFilter, |
| 3 | + InternalError, |
| 4 | + createFilterExtension, |
| 5 | +} from '@vulcan-sql/core'; |
| 6 | +import { has, isArray, isEmpty, omit, pick } from 'lodash'; |
| 7 | +import { |
| 8 | + HuggingFaceOptions, |
| 9 | + InferenceNLPOptions, |
| 10 | + apiInferenceEndpoint, |
| 11 | +} from '../model'; |
| 12 | +import { postRequest } from '../utils'; |
| 13 | + |
| 14 | +// More information described the options. See: https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task |
| 15 | +type TextGenerationOptions = { |
| 16 | + inputs: string; |
| 17 | + parameters?: { |
| 18 | + // Integer to define the top tokens considered within the sample operation to create new text. |
| 19 | + top_k?: number; |
| 20 | + // Float to define the tokens that are within the sample operation of text generation. Add tokens in the sample for more probable to least probable until the sum of the probabilities is greater than top_p. |
| 21 | + top_p?: number; |
| 22 | + // Default: 0.1. Range: (0.0 - 100.0). The temperature of the sampling operation. 1 means regular sampling, 0 means always take the highest score, 100.0 is getting closer to uniform probability. |
| 23 | + temperature?: number; |
| 24 | + // Range: (0.0 - 100.0). The more a token is used within generation the more it is penalized to not be picked in successive generation passes. |
| 25 | + repetition_penalty?: number; |
| 26 | + // Default: 250. The amount of new tokens to be generated, this does not include the input length it is a estimate of the size of generated text you want. Each new tokens slows down the request, so look for balance between response times and length of text generated. |
| 27 | + max_new_tokens?: number; |
| 28 | + // Range (0-120.0). The amount of time in seconds that the query should take maximum. Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens for best results. |
| 29 | + max_time?: number; |
| 30 | + // Default: false. If set to False, the return results will not contain the original query making it easier for prompting. |
| 31 | + return_full_text?: boolean; |
| 32 | + // Default: 1. The number of proposition you want to be returned. |
| 33 | + num_return_sequences?: number; |
| 34 | + // Whether or not to use sampling, use greedy decoding otherwise. |
| 35 | + do_sample?: boolean; |
| 36 | + }; |
| 37 | + options?: InferenceNLPOptions; |
| 38 | +}; |
| 39 | + |
| 40 | +/** |
| 41 | + * Get text generation url. Used meta Llama 2.0 model be default value. |
| 42 | + * See: https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task |
| 43 | + * */ |
| 44 | +const getUrl = (model = 'gpt2') => `${apiInferenceEndpoint}/${model}`; |
| 45 | + |
| 46 | +export const TextGenerationFilter: FunctionalFilter = async ({ |
| 47 | + args, |
| 48 | + value, |
| 49 | + options, |
| 50 | +}) => { |
| 51 | + const token = (options as HuggingFaceOptions)?.accessToken; |
| 52 | + if (!token) throw new InternalError('please given access token'); |
| 53 | + |
| 54 | + if (!isArray(value)) |
| 55 | + throw new InternalError('Input value must be an array of object'); |
| 56 | + if (!(typeof args === 'object') || !has(args, 'query')) |
| 57 | + throw new InternalError('Must provide "query" keyword argument'); |
| 58 | + |
| 59 | + // Convert the data result to JSON string as question context |
| 60 | + const context = JSON.stringify(value); |
| 61 | + // omit hidden value '__keywords' from args, it generated from nunjucks and not related to HuggingFace. |
| 62 | + const { query, model, endpoint, ...otherArgs } = omit(args, '__keywords'); |
| 63 | + const inferenceOptions = pick(otherArgs, ['use_cache', 'wait_for_model']); |
| 64 | + const parameters = omit(otherArgs, ['use_cache', 'wait_for_model', 'endpoint']); |
| 65 | + const payload = { |
| 66 | + inputs: `Context: ${context}. Question: ${query}}`, |
| 67 | + parameters: { |
| 68 | + return_full_text: false, |
| 69 | + max_new_tokens: 250, |
| 70 | + temperature: 0.1, |
| 71 | + } |
| 72 | + } as TextGenerationOptions; |
| 73 | + if (!isEmpty(parameters)) payload.parameters = parameters; |
| 74 | + if (!isEmpty(inferenceOptions)) payload.options = inferenceOptions; |
| 75 | + |
| 76 | + try { |
| 77 | + // if not given endpoint, use default HuggingFace inference endpoint |
| 78 | + const url = endpoint ? endpoint : getUrl(model); |
| 79 | + const results = await postRequest(url, payload, token); |
| 80 | + // get the "generated_text" field, and trim leading and trailing white space. |
| 81 | + return String(results[0]['generated_text']).trim(); |
| 82 | + } catch (error) { |
| 83 | + throw new InternalError( |
| 84 | + `Error when sending data to Hugging Face for executing TextGeneration tasks, details: ${ |
| 85 | + (error as Error).message |
| 86 | + }` |
| 87 | + ); |
| 88 | + } |
| 89 | +}; |
| 90 | + |
| 91 | +export const [Builder, Runner] = createFilterExtension( |
| 92 | + 'huggingface_text_generation', |
| 93 | + TextGenerationFilter |
| 94 | +); |
0 commit comments