Skip to content

Commit 181464f

Browse files
committed
add support for 1bit vector functions
1 parent 4b562ed commit 181464f

4 files changed

Lines changed: 176 additions & 41 deletions

File tree

libsql-sqlite3/src/vector.c

Lines changed: 152 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ size_t vectorDataSize(VectorType type, VectorDims dims){
4242
case VECTOR_TYPE_FLOAT64:
4343
return dims * sizeof(double);
4444
case VECTOR_TYPE_1BIT:
45-
assert( dims > 0 );
4645
return (dims + 7) / 8;
4746
default:
4847
assert(0);
@@ -253,33 +252,84 @@ static int vectorParseSqliteText(
253252
return -1;
254253
}
255254

255+
static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pType, int *pDims, size_t *pDataSize, char **pzErrMsg){
256+
int nLeftoverBits;
257+
258+
if( nBlobSize % 2 == 0 ){
259+
*pType = VECTOR_TYPE_FLOAT32;
260+
*pDims = nBlobSize / sizeof(float);
261+
*pDataSize = nBlobSize;
262+
return SQLITE_OK;
263+
}
264+
*pType = pBlob[nBlobSize - 1];
265+
nBlobSize--;
266+
267+
if( *pType == VECTOR_TYPE_FLOAT32 ){
268+
if( nBlobSize % 4 != 0 ){
269+
*pzErrMsg = sqlite3_mprintf("invalid vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize);
270+
return SQLITE_ERROR;
271+
}
272+
*pDims = nBlobSize / sizeof(float);
273+
*pDataSize = nBlobSize;
274+
}else if( *pType == VECTOR_TYPE_FLOAT64 ){
275+
if( nBlobSize % 8 != 0 ){
276+
*pzErrMsg = sqlite3_mprintf("invalid vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize);
277+
return SQLITE_ERROR;
278+
}
279+
*pDims = nBlobSize / sizeof(double);
280+
*pDataSize = nBlobSize;
281+
}else if( *pType == VECTOR_TYPE_1BIT ){
282+
if( nBlobSize == 0 || nBlobSize % 2 != 0 ){
283+
*pzErrMsg = sqlite3_mprintf("invalid vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize);
284+
return SQLITE_ERROR;
285+
}
286+
nLeftoverBits = pBlob[nBlobSize - 1];
287+
*pDims = nBlobSize * 8 - nLeftoverBits;
288+
*pDataSize = (*pDims + 7) / 8;
289+
}else{
290+
*pzErrMsg = sqlite3_mprintf("invalid vector: unexpected type: %d", *pType);
291+
return SQLITE_ERROR;
292+
}
293+
return SQLITE_OK;
294+
}
295+
256296
int vectorParseSqliteBlobWithType(
257297
sqlite3_value *arg,
258298
Vector *pVector,
259299
char **pzErrMsg
260300
){
261301
const unsigned char *pBlob;
262-
size_t nBlobSize;
302+
size_t nBlobSize, nDataSize;
303+
int type, dims;
263304

264305
assert( sqlite3_value_type(arg) == SQLITE_BLOB );
265306

266307
pBlob = sqlite3_value_blob(arg);
267308
nBlobSize = sqlite3_value_bytes(arg);
268-
if( nBlobSize % 2 == 1 ){
269-
nBlobSize--;
309+
if( vectorParseMeta(pBlob, nBlobSize, &type, &dims, &nDataSize, pzErrMsg) != SQLITE_OK ){
310+
return SQLITE_ERROR;
270311
}
271312

272-
if( nBlobSize < vectorDataSize(pVector->type, pVector->dims) ){
273-
*pzErrMsg = sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%ull", pVector->type, pVector->dims, nBlobSize);
313+
if( nDataSize != vectorDataSize(pVector->type, pVector->dims) ){
314+
*pzErrMsg = sqlite3_mprintf(
315+
"invalid vector: unexpected data size bytes: type=%d, dims=%d, %ull != %ull",
316+
pVector->type,
317+
pVector->dims,
318+
nDataSize,
319+
vectorDataSize(pVector->type, pVector->dims)
320+
);
274321
return SQLITE_ERROR;
275322
}
276323

277324
switch (pVector->type) {
278325
case VECTOR_TYPE_FLOAT32:
279-
vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize);
326+
vectorF32DeserializeFromBlob(pVector, pBlob, nDataSize);
280327
return 0;
281328
case VECTOR_TYPE_FLOAT64:
282-
vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize);
329+
vectorF64DeserializeFromBlob(pVector, pBlob, nDataSize);
330+
return 0;
331+
case VECTOR_TYPE_1BIT:
332+
vector1BitDeserializeFromBlob(pVector, pBlob, nDataSize);
283333
return 0;
284334
default:
285335
assert(0);
@@ -298,15 +348,22 @@ int detectBlobVectorParameters(sqlite3_value *arg, int *pType, int *pDims, char
298348
if( nBlobSize % 2 != 0 ){
299349
// we have trailing byte with explicit type definition
300350
*pType = pBlob[nBlobSize - 1];
351+
nBlobSize--;
301352
} else {
302353
// else, fallback to FLOAT32
303354
*pType = VECTOR_TYPE_FLOAT32;
304355
}
305356
if( *pType == VECTOR_TYPE_FLOAT32 ){
306357
*pDims = nBlobSize / sizeof(float);
307-
} else if( *pType == VECTOR_TYPE_FLOAT64 ){
358+
}else if( *pType == VECTOR_TYPE_FLOAT64 ){
308359
*pDims = nBlobSize / sizeof(double);
309-
} else{
360+
}else if( *pType == VECTOR_TYPE_1BIT ){
361+
if( nBlobSize == 0 || nBlobSize % 2 != 0 ){
362+
*pzErrMsg = sqlite3_mprintf("vector: malformed 1bit float: blob size must has even size (without last byte): size=%d", nBlobSize);
363+
return -1;
364+
}
365+
*pDims = nBlobSize * 8 - pBlob[nBlobSize - 1];
366+
}else{
310367
*pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: got %d, expected %d or %d", *pType, VECTOR_TYPE_FLOAT32, VECTOR_TYPE_FLOAT64);
311368
return -1;
312369
}
@@ -411,21 +468,55 @@ void vectorMarshalToText(
411468
}
412469
}
413470

414-
void vectorSerializeWithType(
471+
/*
** Return the amount of metadata bytes appended after the vector data in the
** serialized blob:
**   FLOAT32 - 0 bytes (default type, no metadata);
**   FLOAT64 - 1 byte  (vector type);
**   1BIT    - 1 byte with the amount of leftover bits, plus 1 padding byte
**             when the data size is even, plus 1 byte with the vector type
**             (so data + metadata always has odd length).
**
** Unknown types trip an assert; when asserts are compiled out 0 is returned so
** the function never falls off its end without a value.
*/
static int vectorMetaSize(VectorType type, VectorDims dims){
  size_t nDataSize;
  int nMetaSize;

  switch( type ){
    case VECTOR_TYPE_FLOAT32:
      return 0;
    case VECTOR_TYPE_FLOAT64:
      return 1;
    case VECTOR_TYPE_1BIT:
      nDataSize = vectorDataSize(type, dims);
      nMetaSize = 1;   /* one byte which specify amount of leftover bits */
      if( nDataSize % 2 == 0 ){
        nMetaSize++;   /* pad "leftover-bits" byte to the even length */
      }
      nMetaSize++;     /* one byte for vector type */
      return nMetaSize;
    default:
      assert( 0 );
      return 0;
  }
}
490+
491+
static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigned char *pBlob, size_t nBlobSize){
492+
if( pVector->type == VECTOR_TYPE_FLOAT32 ){
493+
// no meta for f32 type as this is "default" vector type
494+
}else if( pVector->type == VECTOR_TYPE_FLOAT64 ){
495+
assert( nDataSize % 2 == 0 );
496+
assert( nBlobSize == nDataSize + 1 );
497+
pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64;
498+
}else if( pVector->type == VECTOR_TYPE_1BIT ){
499+
assert( nBlobSize % 2 == 1 );
500+
assert( nBlobSize >= 3 );
501+
pBlob[nBlobSize - 1] = VECTOR_TYPE_1BIT;
502+
pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims;
503+
}else{
504+
assert( 0 );
505+
}
506+
}
507+
508+
void vectorSerializeWithMeta(
415509
sqlite3_context *context,
416510
const Vector *pVector
417511
){
418512
unsigned char *pBlob;
419-
size_t nBlobSize, nDataSize;
513+
size_t nBlobSize, nDataSize, nMetaSize;
420514

421515
assert( pVector->dims <= MAX_VECTOR_SZ );
422516

423517
nDataSize = vectorDataSize(pVector->type, pVector->dims);
424-
nBlobSize = nDataSize;
425-
if( pVector->type != VECTOR_TYPE_FLOAT32 ){
426-
nBlobSize += (nBlobSize % 2 == 0 ? 1 : 2);
427-
}
428-
518+
nMetaSize = vectorMetaSize(pVector->type, pVector->dims);
519+
nBlobSize = nDataSize + nMetaSize;
429520
if( nBlobSize == 0 ){
430521
sqlite3_result_zeroblob(context, 0);
431522
return;
@@ -437,20 +528,20 @@ void vectorSerializeWithType(
437528
return;
438529
}
439530

440-
if( pVector->type != VECTOR_TYPE_FLOAT32 ){
441-
pBlob[nBlobSize - 1] = pVector->type;
442-
}
443-
444531
switch (pVector->type) {
445532
case VECTOR_TYPE_FLOAT32:
446533
vectorF32SerializeToBlob(pVector, pBlob, nDataSize);
447534
break;
448535
case VECTOR_TYPE_FLOAT64:
449536
vectorF64SerializeToBlob(pVector, pBlob, nDataSize);
450537
break;
538+
case VECTOR_TYPE_1BIT:
539+
vector1BitSerializeToBlob(pVector, pBlob, nDataSize);
540+
break;
451541
default:
452542
assert(0);
453543
}
544+
vectorSerializeMeta(pVector, nDataSize, pBlob, nBlobSize);
454545
sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free);
455546
}
456547

@@ -614,11 +705,15 @@ static void vectorFuncHintedType(
614705
){
615706
char *pzErrMsg = NULL;
616707
Vector *pVector = NULL, *pTarget = NULL;
617-
int type, dims;
708+
int type, dims, typeHint = VECTOR_TYPE_FLOAT32;
618709
if( argc < 1 ){
619710
goto out;
620711
}
621-
if( detectVectorParameters(argv[0], targetType, &type, &dims, &pzErrMsg) != 0 ){
712+
// simplification in order to support only parsing from text to f32 and f64 vectors
713+
if( targetType == VECTOR_TYPE_FLOAT64 ){
714+
typeHint = targetType;
715+
}
716+
if( detectVectorParameters(argv[0], typeHint, &type, &dims, &pzErrMsg) != 0 ){
622717
sqlite3_result_error(context, pzErrMsg, -1);
623718
sqlite3_free(pzErrMsg);
624719
goto out;
@@ -633,14 +728,14 @@ static void vectorFuncHintedType(
633728
goto out;
634729
}
635730
if( type == targetType ){
636-
vectorSerializeWithType(context, pVector);
731+
vectorSerializeWithMeta(context, pVector);
637732
}else{
638733
pTarget = vectorContextAlloc(context, targetType, dims);
639734
if( pTarget == NULL ){
640735
goto out;
641736
}
642737
vectorConvert(pVector, pTarget);
643-
vectorSerializeWithType(context, pTarget);
738+
vectorSerializeWithMeta(context, pTarget);
644739
}
645740
out:
646741
if( pVector != NULL ){
@@ -666,6 +761,14 @@ static void vector64Func(
666761
vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT64);
667762
}
668763

764+
static void vector1BitFunc(
765+
sqlite3_context *context,
766+
int argc,
767+
sqlite3_value **argv
768+
){
769+
vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_1BIT);
770+
}
771+
669772
/*
670773
** Implementation of vector_extract(X) function.
671774
*/
@@ -675,30 +778,44 @@ static void vectorExtractFunc(
675778
sqlite3_value **argv
676779
){
677780
char *pzErrMsg = NULL;
678-
Vector *pVector;
781+
Vector *pVector = NULL, *pTarget = NULL;
679782
unsigned i;
680783
int type, dims;
681784

682785
if( argc < 1 ){
683-
return;
786+
goto out;
684787
}
685788
if( detectVectorParameters(argv[0], 0, &type, &dims, &pzErrMsg) != 0 ){
686789
sqlite3_result_error(context, pzErrMsg, -1);
687790
sqlite3_free(pzErrMsg);
688-
return;
791+
goto out;
689792
}
690793
pVector = vectorContextAlloc(context, type, dims);
691-
if( pVector==NULL ){
692-
return;
794+
if( pVector == NULL ){
795+
goto out;
693796
}
694797
if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){
695798
sqlite3_result_error(context, pzErrMsg, -1);
696799
sqlite3_free(pzErrMsg);
697-
goto out_free;
800+
goto out;
801+
}
802+
if( pVector->type == VECTOR_TYPE_FLOAT32 || pVector->type == VECTOR_TYPE_FLOAT64 ){
803+
vectorMarshalToText(context, pVector);
804+
}else{
805+
pTarget = vectorContextAlloc(context, VECTOR_TYPE_FLOAT32, dims);
806+
if( pTarget == NULL ){
807+
goto out;
808+
}
809+
vectorConvert(pVector, pTarget);
810+
vectorMarshalToText(context, pTarget);
811+
}
812+
out:
813+
if( pVector != NULL ){
814+
vectorFree(pVector);
815+
}
816+
if( pTarget != NULL ){
817+
vectorFree(pTarget);
698818
}
699-
vectorMarshalToText(context, pVector);
700-
out_free:
701-
vectorFree(pVector);
702819
}
703820

704821
/*
@@ -782,6 +899,7 @@ void sqlite3RegisterVectorFunctions(void){
782899
FUNCTION(vector, 1, 0, 0, vector32Func),
783900
FUNCTION(vector32, 1, 0, 0, vector32Func),
784901
FUNCTION(vector64, 1, 0, 0, vector64Func),
902+
FUNCTION(vector1bit, 1, 0, 0, vector1BitFunc),
785903
FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc),
786904
FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc),
787905

libsql-sqlite3/src/vector1bit.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,4 +124,18 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){
124124
return diff;
125125
}
126126

127+
void vector1BitDeserializeFromBlob(
128+
Vector *pVector,
129+
const unsigned char *pBlob,
130+
size_t nBlobSize
131+
){
132+
u8 *elems = pVector->data;
133+
134+
assert( pVector->type == VECTOR_TYPE_1BIT );
135+
assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ );
136+
assert( nBlobSize >= (pVector->dims + 7) / 8 );
137+
138+
memcpy(elems, pBlob, (pVector->dims + 7) / 8);
139+
}
140+
127141
#endif /* !defined(SQLITE_OMIT_VECTOR) */

libsql-sqlite3/src/vectorIndex.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -378,10 +378,12 @@ struct VectorColumnType {
378378
};
379379

380380
static struct VectorColumnType VECTOR_COLUMN_TYPES[] = {
381-
{ "FLOAT32", VECTOR_TYPE_FLOAT32 },
382-
{ "FLOAT64", VECTOR_TYPE_FLOAT64 },
383-
{ "F32_BLOB", VECTOR_TYPE_FLOAT32 },
384-
{ "F64_BLOB", VECTOR_TYPE_FLOAT64 }
381+
{ "FLOAT32", VECTOR_TYPE_FLOAT32 },
382+
{ "F32_BLOB", VECTOR_TYPE_FLOAT32 },
383+
{ "FLOAT64", VECTOR_TYPE_FLOAT64 },
384+
{ "F64_BLOB", VECTOR_TYPE_FLOAT64 },
385+
{ "FLOAT1BIT", VECTOR_TYPE_1BIT },
386+
{ "F1BIT_BLOB", VECTOR_TYPE_1BIT },
385387
};
386388

387389
/*

libsql-sqlite3/src/vectorInt.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,16 @@ double vectorF64DistanceL2(const Vector *, const Vector *);
9292
* LibSQL can append one trailing byte in the end of final blob. This byte will be later used to determine type of the blob
9393
* By default, blob with even length will be treated as a f32 blob
9494
*/
95-
void vectorSerializeWithType(sqlite3_context *, const Vector *);
95+
void vectorSerializeWithMeta(sqlite3_context *, const Vector *);
9696

9797
/*
9898
* Parses Vector content from the blob; vector type and dimensions must be filled already
9999
*/
100100
int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **);
101101

102-
void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t);
103-
void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t);
102+
void vectorF32DeserializeFromBlob (Vector *, const unsigned char *, size_t);
103+
void vectorF64DeserializeFromBlob (Vector *, const unsigned char *, size_t);
104+
void vector1BitDeserializeFromBlob(Vector *, const unsigned char *, size_t);
104105

105106
void vectorInitStatic(Vector *, VectorType, VectorDims, void *);
106107
void vectorInitFromBlob(Vector *, const unsigned char *, size_t);

0 commit comments

Comments
 (0)