@@ -42,7 +42,6 @@ size_t vectorDataSize(VectorType type, VectorDims dims){
4242 case VECTOR_TYPE_FLOAT64 :
4343 return dims * sizeof (double );
4444 case VECTOR_TYPE_1BIT :
45- assert ( dims > 0 );
4645 return (dims + 7 ) / 8 ;
4746 default :
4847 assert (0 );
@@ -253,33 +252,84 @@ static int vectorParseSqliteText(
253252 return -1 ;
254253}
255254
255+ static int vectorParseMeta (const unsigned char * pBlob , size_t nBlobSize , int * pType , int * pDims , size_t * pDataSize , char * * pzErrMsg ){
256+ int nLeftoverBits ;
257+
258+ if ( nBlobSize % 2 == 0 ){
259+ * pType = VECTOR_TYPE_FLOAT32 ;
260+ * pDims = nBlobSize / sizeof (float );
261+ * pDataSize = nBlobSize ;
262+ return SQLITE_OK ;
263+ }
264+ * pType = pBlob [nBlobSize - 1 ];
265+ nBlobSize -- ;
266+
267+ if ( * pType == VECTOR_TYPE_FLOAT32 ){
268+ if ( nBlobSize % 4 != 0 ){
269+ * pzErrMsg = sqlite3_mprintf ("invalid vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d" , nBlobSize );
270+ return SQLITE_ERROR ;
271+ }
272+ * pDims = nBlobSize / sizeof (float );
273+ * pDataSize = nBlobSize ;
274+ }else if ( * pType == VECTOR_TYPE_FLOAT64 ){
275+ if ( nBlobSize % 8 != 0 ){
276+ * pzErrMsg = sqlite3_mprintf ("invalid vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d" , nBlobSize );
277+ return SQLITE_ERROR ;
278+ }
279+ * pDims = nBlobSize / sizeof (double );
280+ * pDataSize = nBlobSize ;
281+ }else if ( * pType == VECTOR_TYPE_1BIT ){
282+ if ( nBlobSize == 0 || nBlobSize % 2 != 0 ){
283+ * pzErrMsg = sqlite3_mprintf ("invalid vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d" , nBlobSize );
284+ return SQLITE_ERROR ;
285+ }
286+ nLeftoverBits = pBlob [nBlobSize - 1 ];
287+ * pDims = nBlobSize * 8 - nLeftoverBits ;
288+ * pDataSize = (* pDims + 7 ) / 8 ;
289+ }else {
290+ * pzErrMsg = sqlite3_mprintf ("invalid vector: unexpected type: %d" , * pType );
291+ return SQLITE_ERROR ;
292+ }
293+ return SQLITE_OK ;
294+ }
295+
256296int vectorParseSqliteBlobWithType (
257297 sqlite3_value * arg ,
258298 Vector * pVector ,
259299 char * * pzErrMsg
260300){
261301 const unsigned char * pBlob ;
262- size_t nBlobSize ;
302+ size_t nBlobSize , nDataSize ;
303+ int type , dims ;
263304
264305 assert ( sqlite3_value_type (arg ) == SQLITE_BLOB );
265306
266307 pBlob = sqlite3_value_blob (arg );
267308 nBlobSize = sqlite3_value_bytes (arg );
268- if ( nBlobSize % 2 == 1 ){
269- nBlobSize -- ;
309+ if ( vectorParseMeta ( pBlob , nBlobSize , & type , & dims , & nDataSize , pzErrMsg ) != SQLITE_OK ){
310+ return SQLITE_ERROR ;
270311 }
271312
272- if ( nBlobSize < vectorDataSize (pVector -> type , pVector -> dims ) ){
273- * pzErrMsg = sqlite3_mprintf ("invalid vector: not enough bytes: type=%d, dims=%d, size=%ull" , pVector -> type , pVector -> dims , nBlobSize );
313+ if ( nDataSize != vectorDataSize (pVector -> type , pVector -> dims ) ){
314+ * pzErrMsg = sqlite3_mprintf (
315+ "invalid vector: unexpected data size bytes: type=%d, dims=%d, %ull != %ull" ,
316+ pVector -> type ,
317+ pVector -> dims ,
318+ nDataSize ,
319+ vectorDataSize (pVector -> type , pVector -> dims )
320+ );
274321 return SQLITE_ERROR ;
275322 }
276323
277324 switch (pVector -> type ) {
278325 case VECTOR_TYPE_FLOAT32 :
279- vectorF32DeserializeFromBlob (pVector , pBlob , nBlobSize );
326+ vectorF32DeserializeFromBlob (pVector , pBlob , nDataSize );
280327 return 0 ;
281328 case VECTOR_TYPE_FLOAT64 :
282- vectorF64DeserializeFromBlob (pVector , pBlob , nBlobSize );
329+ vectorF64DeserializeFromBlob (pVector , pBlob , nDataSize );
330+ return 0 ;
331+ case VECTOR_TYPE_1BIT :
332+ vector1BitDeserializeFromBlob (pVector , pBlob , nDataSize );
283333 return 0 ;
284334 default :
285335 assert (0 );
@@ -298,15 +348,22 @@ int detectBlobVectorParameters(sqlite3_value *arg, int *pType, int *pDims, char
298348 if ( nBlobSize % 2 != 0 ){
299349 // we have trailing byte with explicit type definition
300350 * pType = pBlob [nBlobSize - 1 ];
351+ nBlobSize -- ;
301352 } else {
302353 // else, fallback to FLOAT32
303354 * pType = VECTOR_TYPE_FLOAT32 ;
304355 }
305356 if ( * pType == VECTOR_TYPE_FLOAT32 ){
306357 * pDims = nBlobSize / sizeof (float );
307- } else if ( * pType == VECTOR_TYPE_FLOAT64 ){
358+ }else if ( * pType == VECTOR_TYPE_FLOAT64 ){
308359 * pDims = nBlobSize / sizeof (double );
309- } else {
360+ }else if ( * pType == VECTOR_TYPE_1BIT ){
361+ if ( nBlobSize == 0 || nBlobSize % 2 != 0 ){
362+ * pzErrMsg = sqlite3_mprintf ("vector: malformed 1bit float: blob size must has even size (without last byte): size=%d" , nBlobSize );
363+ return -1 ;
364+ }
365+ * pDims = nBlobSize * 8 - pBlob [nBlobSize - 1 ];
366+ }else {
310367 * pzErrMsg = sqlite3_mprintf ("vector: unexpected binary type: got %d, expected %d or %d" , * pType , VECTOR_TYPE_FLOAT32 , VECTOR_TYPE_FLOAT64 );
311368 return -1 ;
312369 }
@@ -411,21 +468,55 @@ void vectorMarshalToText(
411468 }
412469}
413470
414- void vectorSerializeWithType (
471+ static int vectorMetaSize (VectorType type , VectorDims dims ){
472+ int nMetaSize = 0 ;
473+ int nDataSize ;
474+ if ( type == VECTOR_TYPE_FLOAT32 ){
475+ return 0 ;
476+ }else if ( type == VECTOR_TYPE_FLOAT64 ){
477+ return 1 ;
478+ }else if ( type == VECTOR_TYPE_1BIT ){
479+ nDataSize = vectorDataSize (type , dims );
480+ nMetaSize ++ ; // one byte which specify amount of leftover bits
481+ if ( nDataSize % 2 == 0 ){
482+ nMetaSize ++ ; // pad "leftover-bits" byte to the even length
483+ }
484+ nMetaSize ++ ; // one byte for vector type
485+ return nMetaSize ;
486+ }else {
487+ assert ( 0 );
488+ }
489+ }
490+
491+ static void vectorSerializeMeta (const Vector * pVector , size_t nDataSize , unsigned char * pBlob , size_t nBlobSize ){
492+ if ( pVector -> type == VECTOR_TYPE_FLOAT32 ){
493+ // no meta for f32 type as this is "default" vector type
494+ }else if ( pVector -> type == VECTOR_TYPE_FLOAT64 ){
495+ assert ( nDataSize % 2 == 0 );
496+ assert ( nBlobSize == nDataSize + 1 );
497+ pBlob [nBlobSize - 1 ] = VECTOR_TYPE_FLOAT64 ;
498+ }else if ( pVector -> type == VECTOR_TYPE_1BIT ){
499+ assert ( nBlobSize % 2 == 1 );
500+ assert ( nBlobSize >= 3 );
501+ pBlob [nBlobSize - 1 ] = VECTOR_TYPE_1BIT ;
502+ pBlob [nBlobSize - 2 ] = 8 * (nBlobSize - 1 ) - pVector -> dims ;
503+ }else {
504+ assert ( 0 );
505+ }
506+ }
507+
508+ void vectorSerializeWithMeta (
415509 sqlite3_context * context ,
416510 const Vector * pVector
417511){
418512 unsigned char * pBlob ;
419- size_t nBlobSize , nDataSize ;
513+ size_t nBlobSize , nDataSize , nMetaSize ;
420514
421515 assert ( pVector -> dims <= MAX_VECTOR_SZ );
422516
423517 nDataSize = vectorDataSize (pVector -> type , pVector -> dims );
424- nBlobSize = nDataSize ;
425- if ( pVector -> type != VECTOR_TYPE_FLOAT32 ){
426- nBlobSize += (nBlobSize % 2 == 0 ? 1 : 2 );
427- }
428-
518+ nMetaSize = vectorMetaSize (pVector -> type , pVector -> dims );
519+ nBlobSize = nDataSize + nMetaSize ;
429520 if ( nBlobSize == 0 ){
430521 sqlite3_result_zeroblob (context , 0 );
431522 return ;
@@ -437,20 +528,20 @@ void vectorSerializeWithType(
437528 return ;
438529 }
439530
440- if ( pVector -> type != VECTOR_TYPE_FLOAT32 ){
441- pBlob [nBlobSize - 1 ] = pVector -> type ;
442- }
443-
444531 switch (pVector -> type ) {
445532 case VECTOR_TYPE_FLOAT32 :
446533 vectorF32SerializeToBlob (pVector , pBlob , nDataSize );
447534 break ;
448535 case VECTOR_TYPE_FLOAT64 :
449536 vectorF64SerializeToBlob (pVector , pBlob , nDataSize );
450537 break ;
538+ case VECTOR_TYPE_1BIT :
539+ vector1BitSerializeToBlob (pVector , pBlob , nDataSize );
540+ break ;
451541 default :
452542 assert (0 );
453543 }
544+ vectorSerializeMeta (pVector , nDataSize , pBlob , nBlobSize );
454545 sqlite3_result_blob (context , (char * )pBlob , nBlobSize , sqlite3_free );
455546}
456547
@@ -614,11 +705,15 @@ static void vectorFuncHintedType(
614705){
615706 char * pzErrMsg = NULL ;
616707 Vector * pVector = NULL , * pTarget = NULL ;
617- int type , dims ;
708+ int type , dims , typeHint = VECTOR_TYPE_FLOAT32 ;
618709 if ( argc < 1 ){
619710 goto out ;
620711 }
621- if ( detectVectorParameters (argv [0 ], targetType , & type , & dims , & pzErrMsg ) != 0 ){
712+ // simplification in order to support only parsing from text to f32 and f64 vectors
713+ if ( targetType == VECTOR_TYPE_FLOAT64 ){
714+ typeHint = targetType ;
715+ }
716+ if ( detectVectorParameters (argv [0 ], typeHint , & type , & dims , & pzErrMsg ) != 0 ){
622717 sqlite3_result_error (context , pzErrMsg , -1 );
623718 sqlite3_free (pzErrMsg );
624719 goto out ;
@@ -633,14 +728,14 @@ static void vectorFuncHintedType(
633728 goto out ;
634729 }
635730 if ( type == targetType ){
636- vectorSerializeWithType (context , pVector );
731+ vectorSerializeWithMeta (context , pVector );
637732 }else {
638733 pTarget = vectorContextAlloc (context , targetType , dims );
639734 if ( pTarget == NULL ){
640735 goto out ;
641736 }
642737 vectorConvert (pVector , pTarget );
643- vectorSerializeWithType (context , pTarget );
738+ vectorSerializeWithMeta (context , pTarget );
644739 }
645740out :
646741 if ( pVector != NULL ){
@@ -666,6 +761,14 @@ static void vector64Func(
666761 vectorFuncHintedType (context , argc , argv , VECTOR_TYPE_FLOAT64 );
667762}
668763
764+ static void vector1BitFunc (
765+ sqlite3_context * context ,
766+ int argc ,
767+ sqlite3_value * * argv
768+ ){
769+ vectorFuncHintedType (context , argc , argv , VECTOR_TYPE_1BIT );
770+ }
771+
669772/*
670773** Implementation of vector_extract(X) function.
671774*/
@@ -675,30 +778,44 @@ static void vectorExtractFunc(
675778 sqlite3_value * * argv
676779){
677780 char * pzErrMsg = NULL ;
678- Vector * pVector ;
781+ Vector * pVector = NULL , * pTarget = NULL ;
679782 unsigned i ;
680783 int type , dims ;
681784
682785 if ( argc < 1 ){
683- return ;
786+ goto out ;
684787 }
685788 if ( detectVectorParameters (argv [0 ], 0 , & type , & dims , & pzErrMsg ) != 0 ){
686789 sqlite3_result_error (context , pzErrMsg , -1 );
687790 sqlite3_free (pzErrMsg );
688- return ;
791+ goto out ;
689792 }
690793 pVector = vectorContextAlloc (context , type , dims );
691- if ( pVector == NULL ){
692- return ;
794+ if ( pVector == NULL ){
795+ goto out ;
693796 }
694797 if ( vectorParseWithType (argv [0 ], pVector , & pzErrMsg )< 0 ){
695798 sqlite3_result_error (context , pzErrMsg , -1 );
696799 sqlite3_free (pzErrMsg );
697- goto out_free ;
800+ goto out ;
801+ }
802+ if ( pVector -> type == VECTOR_TYPE_FLOAT32 || pVector -> type == VECTOR_TYPE_FLOAT64 ){
803+ vectorMarshalToText (context , pVector );
804+ }else {
805+ pTarget = vectorContextAlloc (context , VECTOR_TYPE_FLOAT32 , dims );
806+ if ( pTarget == NULL ){
807+ goto out ;
808+ }
809+ vectorConvert (pVector , pTarget );
810+ vectorMarshalToText (context , pTarget );
811+ }
812+ out :
813+ if ( pVector != NULL ){
814+ vectorFree (pVector );
815+ }
816+ if ( pTarget != NULL ){
817+ vectorFree (pTarget );
698818 }
699- vectorMarshalToText (context , pVector );
700- out_free :
701- vectorFree (pVector );
702819}
703820
704821/*
@@ -782,6 +899,7 @@ void sqlite3RegisterVectorFunctions(void){
782899 FUNCTION (vector , 1 , 0 , 0 , vector32Func ),
783900 FUNCTION (vector32 , 1 , 0 , 0 , vector32Func ),
784901 FUNCTION (vector64 , 1 , 0 , 0 , vector64Func ),
902+ FUNCTION (vector1bit , 1 , 0 , 0 , vector1BitFunc ),
785903 FUNCTION (vector_extract , 1 , 0 , 0 , vectorExtractFunc ),
786904 FUNCTION (vector_distance_cos , 2 , 0 , 0 , vectorDistanceCosFunc ),
787905
0 commit comments