@@ -3,82 +3,56 @@ import {
33 InternalError ,
44 createFilterExtension ,
55} from '@vulcan-sql/core' ;
6- import axios , { AxiosError } from 'axios' ;
7- import { convertToHuggingFaceTable } from '../utils' ;
8- import { isArray } from 'class-validator' ;
9- import { has } from 'lodash' ;
106
11- type HuggingFaceOptions = {
12- accessToken : string ;
13- } ;
7+ import { convertToHuggingFaceTable , postRequest } from '../utils' ;
8+ import { has , isArray , isEmpty , omit } from 'lodash' ;
9+ import {
10+ InferenceNLPOptions ,
11+ HuggingFaceOptions ,
12+ apiInferenceEndpoint ,
13+ } from '../model' ;
1414
15- // More information described the options, see : https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task
15+ // More information described the options. See : https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task
1616type TableQuestionAnsweringOptions = {
1717 inputs : {
1818 query : string ;
1919 table : Record < string , string [ ] > ;
2020 } ;
21- options : {
22- use_cache : boolean ;
23- wait_for_model : boolean ;
24- } ;
25- } ;
26-
27- const request = async ( url : string , data : any , token : string ) => {
28- try {
29- const result = await axios . post ( url , data , {
30- headers : { Authorization : `Bearer ${ token } ` } ,
31- } ) ;
32- return result . data ;
33- } catch ( error ) {
34- const axiosError = error as AxiosError ;
35- // https://axios-http.com/docs/handling_errors
36- // if response has error, throw the response error, or throw the request error
37- if ( axiosError . response )
38- throw new Error ( JSON . stringify ( axiosError . response ?. data ) ) ;
39- throw new Error ( axiosError . message ) ;
40- }
21+ options ?: InferenceNLPOptions ;
4122} ;
4223
43- // default recommended model, see https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task
24+ /**
25+ * Get table question answering url. Used recommend model be default value.
26+ * See: https://huggingface.co/docs/api-inference/detailed_parameters#table-question-answering-task
27+ * */
4428const getUrl = ( model = 'google/tapas-base-finetuned-wtq' ) =>
45- `https://api-inference.huggingface.co/models /${ model } ` ;
29+ `${ apiInferenceEndpoint } /${ model } ` ;
4630
4731export const TableQuestionAnsweringFilter : FunctionalFilter = async ( {
4832 args,
4933 value,
5034 options,
5135} ) => {
52- if ( ! options || ! ( options as HuggingFaceOptions ) . accessToken )
53- throw new InternalError ( 'please given access token' ) ;
36+ const token = ( options as HuggingFaceOptions ) ? .accessToken ;
37+ if ( ! token ) throw new InternalError ( 'please given access token' ) ;
5438
5539 if ( ! isArray ( value ) )
5640 throw new InternalError ( 'Input value must be an array of object' ) ;
57-
5841 if ( ! ( typeof args === 'object' ) || ! has ( args , 'query' ) )
5942 throw new InternalError ( 'Must provide "query" keyword argument' ) ;
60- if ( ! args [ 'query' ] )
61- throw new InternalError ( 'The "query" argument must have value' ) ;
6243
63- const token = ( options as HuggingFaceOptions ) . accessToken ;
6444 // Convert the data result format to table value format
6545 const table = convertToHuggingFaceTable ( value ) ;
66- const context = {
67- inputs : {
68- query : args [ 'query' ] ,
69- table,
70- } ,
71- options : {
72- use_cache : args [ 'use_cache' ] ? args [ 'use_cache' ] : true ,
73- wait_for_model : args [ 'wait_for_model' ] ? args [ 'wait_for_model' ] : false ,
74- } ,
46+ // omit hidden value '__keywords' from args, it generated from nunjucks and not related to HuggingFace.
47+ const { query, model, ...inferenceOptions } = omit ( args , '__keywords' ) ;
48+ const payload = {
49+ inputs : { query, table } ,
7550 } as TableQuestionAnsweringOptions ;
76-
77- // Get table question answering url
78- const url = args [ 'model' ] ? getUrl ( args [ 'model' ] ) : getUrl ( ) ;
51+ if ( ! isEmpty ( inferenceOptions ) ) payload . options = inferenceOptions ;
7952
8053 try {
81- const results = await request ( url , context , token ) ;
54+ const url = getUrl ( model ) ;
55+ const results = await postRequest ( url , payload , token ) ;
8256 // convert to JSON string to make user get the whole result after parsing it in SQL
8357 return JSON . stringify ( results ) ;
8458 } catch ( error ) {
0 commit comments