Skip to content

Commit 1a9cab9

Browse files
committed
1bit quantized embeddings search: somehow working version
1 parent 2e696fe commit 1a9cab9

5 files changed

Lines changed: 166 additions & 61 deletions

File tree

libsql-sqlite3/src/vector.c

Lines changed: 29 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -74,10 +74,11 @@ Vector *vectorAlloc(VectorType type, VectorDims dims){
7474
** Note that the vector object points to the blob so if
7575
** you free the blob, the vector becomes invalid.
7676
**/
77-
void vectorInitStatic(Vector *pVector, VectorType type, const unsigned char *pBlob, size_t nBlobSize){
78-
pVector->type = type;
77+
void vectorInitStatic(Vector *pVector, VectorType type, VectorDims dims, void *pBlob){
7978
pVector->flags = VECTOR_FLAGS_STATIC;
80-
vectorInitFromBlob(pVector, pBlob, nBlobSize);
79+
pVector->type = type;
80+
pVector->dims = dims;
81+
pVector->data = pBlob;
8182
}
8283

8384
/*
@@ -479,6 +480,31 @@ void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlo
479480
}
480481
}
481482

483+
void vectorConvert(const Vector *pFrom, Vector *pTo){
484+
int i;
485+
u8 *bitData;
486+
float *floatData;
487+
488+
assert( pFrom->dims == pTo->dims );
489+
490+
if( pFrom->type == VECTOR_TYPE_FLOAT32 && pTo->type == VECTOR_TYPE_1BIT ){
491+
floatData = pFrom->data;
492+
bitData = pTo->data;
493+
for(i = 0; i < pFrom->dims; i += 8){
494+
bitData[i / 8] = 0;
495+
}
496+
for(i = 0; i < pFrom->dims; i++){
497+
if( floatData[i] < 0 ){
498+
bitData[i / 8] &= ~(1 << (i & 7));
499+
}else{
500+
bitData[i / 8] |= (1 << (i & 7));
501+
}
502+
}
503+
}else{
504+
assert(0);
505+
}
506+
}
507+
482508
/**************************************************************************
483509
** SQL function implementations
484510
****************************************************************************/

libsql-sqlite3/src/vector1bit.c

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -56,17 +56,16 @@ size_t vector1BitSerializeToBlob(
5656
unsigned char *pBlob,
5757
size_t nBlobSize
5858
){
59-
float *elems = pVector->data;
60-
unsigned char *pPtr = pBlob;
61-
size_t len = 0;
59+
u8 *elems = pVector->data;
60+
u8 *pPtr = pBlob;
6261
unsigned i;
6362

6463
assert( pVector->type == VECTOR_TYPE_1BIT );
6564
assert( pVector->dims <= MAX_VECTOR_SZ );
6665
assert( nBlobSize >= (pVector->dims + 7) / 8 );
6766

68-
for(i = 0; i < pVector->dims; i++){
69-
elems[i] = pPtr[i];
67+
for(i = 0; i < (pVector->dims + 7) / 8; i++){
68+
pPtr[i] = elems[i];
7069
}
7170
return (pVector->dims + 7) / 8;
7271
}
@@ -92,7 +91,7 @@ static int BitsCount[256] = {
9291
};
9392

9493
int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){
95-
int sum = 0;
94+
int diff = 0;
9695
u8 *e1 = v1->data;
9796
u8 *e2 = v2->data;
9897
int i;
@@ -101,10 +100,10 @@ int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){
101100
assert( v1->type == VECTOR_TYPE_1BIT );
102101
assert( v2->type == VECTOR_TYPE_1BIT );
103102

104-
for(i = 0; i < v1->dims; i++){
105-
sum += BitsCount[e1[i]&e2[i]];
103+
for(i = 0; i < v1->dims; i += 8){
104+
diff += BitsCount[e1[i/8] ^ e2[i/8]];
106105
}
107-
return sum;
106+
return diff;
108107
}
109108

110109
#endif /* !defined(SQLITE_OMIT_VECTOR) */

libsql-sqlite3/src/vectorIndex.c

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -396,10 +396,10 @@ struct VectorParamName {
396396
};
397397

398398
static struct VectorParamName VECTOR_PARAM_NAMES[] = {
399-
{ "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN },
400-
{ "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS },
401-
{ "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 },
402-
{ "compress_neighbors", VECTOR_METRIC_TYPE_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT },
399+
{ "type", VECTOR_INDEX_TYPE_PARAM_ID, 0, "diskann", VECTOR_INDEX_TYPE_DISKANN },
400+
{ "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "cosine", VECTOR_METRIC_TYPE_COS },
401+
{ "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 },
402+
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "1bit", VECTOR_TYPE_1BIT },
403403
{ "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 },
404404
{ "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 },
405405
{ "insert_l", VECTOR_INSERT_L_PARAM_ID, 1, 0, 0 },

libsql-sqlite3/src/vectorInt.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -102,11 +102,13 @@ int vectorParseSqliteBlob (sqlite3_value *, Vector *, char **);
102102
void vectorF32DeserializeFromBlob(Vector *, const unsigned char *, size_t);
103103
void vectorF64DeserializeFromBlob(Vector *, const unsigned char *, size_t);
104104

105-
void vectorInitStatic(Vector *, VectorType, const unsigned char *, size_t);
105+
void vectorInitStatic(Vector *, VectorType, VectorDims, void *);
106106
void vectorInitFromBlob(Vector *, const unsigned char *, size_t);
107107
void vectorF32InitFromBlob(Vector *, const unsigned char *, size_t);
108108
void vectorF64InitFromBlob(Vector *, const unsigned char *, size_t);
109109

110+
void vectorConvert(const Vector *, Vector *);
111+
110112
/* Detect type and dimension of vector provided with first parameter of sqlite3_value * type */
111113
int detectVectorParameters(sqlite3_value *, int, int *, int *, char **);
112114

0 commit comments

Comments
 (0)