Skip to content

Commit 2e696fe

Browse files
committed
restructure search a bit in order to support compressed edges
1 parent 39e30ea commit 2e696fe

1 file changed

Lines changed: 118 additions & 50 deletions

File tree

libsql-sqlite3/src/vectordiskann.c

Lines changed: 118 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,19 @@ struct DiskAnnNode {
9898
* so caller which puts nodes in the context can forget about resource managmenet (context will take care of this)
9999
*/
100100
struct DiskAnnSearchCtx {
101-
const Vector *pQuery; /* initial query vector; user query for SELECT and row vector for INSERT */
102-
DiskAnnNode **aCandidates; /* array of candidates ordered by distance to the query (ascending) */
103-
double *aDistances; /* array of distances to the query vector */
104-
unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */
105-
unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */
106-
DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */
107-
unsigned int nUnvisited; /* amount of unvisited candidates in the aCadidates array */
108-
int blobMode; /* DISKANN_BLOB_READONLY if we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */
101+
const Vector *pNodeQuery; /* initial query vector; user query for SELECT and row vector for INSERT */
102+
const Vector *pEdgeQuery; /* initial query vector; user query for SELECT and row vector for INSERT */
103+
DiskAnnNode **aCandidates; /* array of candidates ordered by distance to the query (ascending) */
104+
float *aDistances; /* array of distances to the query vector */
105+
unsigned int nCandidates; /* current size of aCandidates/aDistances arrays */
106+
unsigned int maxCandidates; /* max size of aCandidates/aDistances arrays */
107+
DiskAnnNode **aTopCandidates; /* top candidates with exact distance calculated */
108+
float *aTopDistances; /* top candidates exact distances */
109+
int nTopCandidates; /* current size of aTopCandidates/aTopDistances arrays */
110+
int maxTopCandidates; /* max size of aTopCandidates/aTopDistances arrays */
111+
DiskAnnNode *visitedList; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */
112+
unsigned int nUnvisited; /* amount of unvisited candidates in the aCadidates array */
113+
int blobMode; /* DISKANN_BLOB_READONLY if we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */
109114
};
110115

111116
/**************************************************************************
@@ -805,6 +810,53 @@ static int diskAnnDeleteShadowRow(const DiskAnnIndex *pIndex, i64 nRowid){
805810
return rc;
806811
}
807812

813+
/**************************************************************************
814+
** Generic utilities
815+
**************************************************************************/
816+
817+
int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, float distance){
818+
int i;
819+
#ifdef SQLITE_DEBUG
820+
for(i = 0; i < nSize - 1; i++){
821+
assert(aDistances[i] <= aDistances[i + 1]);
822+
}
823+
#endif
824+
for(i = 0; i < nSize; i++){
825+
if( distance < aDistances[i] ){
826+
return i;
827+
}
828+
}
829+
return nSize < nMaxSize ? nSize : -1;
830+
}
831+
832+
void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast) {
833+
int itemsToMove;
834+
835+
assert( nMaxSize > 0 && nItemSize > 0 );
836+
assert( nSize <= nMaxSize );
837+
assert( 0 <= iInsert && iInsert <= nSize && iInsert < nMaxSize );
838+
839+
if( nSize == nMaxSize ){
840+
if( pLast != NULL ){
841+
memcpy(pLast, aBuffer + (nSize - 1) * nItemSize, nItemSize);
842+
}
843+
nSize--;
844+
}
845+
itemsToMove = nSize - iInsert;
846+
memmove(aBuffer + (iInsert + 1) * nItemSize, aBuffer + iInsert * nItemSize, itemsToMove * nItemSize);
847+
memcpy(aBuffer + iInsert * nItemSize, pItem, nItemSize);
848+
}
849+
850+
void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize) {
851+
int itemsToMove;
852+
853+
assert( nItemSize > 0 );
854+
assert( 0 <= iDelete && iDelete < nSize );
855+
856+
itemsToMove = nSize - iDelete - 1;
857+
memmove(aBuffer + iDelete * nItemSize, aBuffer + (iDelete + 1) * nItemSize, itemsToMove * nItemSize);
858+
}
859+
808860
/**************************************************************************
809861
** DiskANN internals
810862
**************************************************************************/
@@ -841,16 +893,21 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){
841893
sqlite3_free(pNode);
842894
}
843895

844-
static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, unsigned int maxCandidates, int blobMode){
845-
pCtx->pQuery = pQuery;
896+
static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, int maxCandidates, int topCandidates, int blobMode){
897+
pCtx->pNodeQuery = pQuery;
898+
pCtx->pEdgeQuery = pQuery;
846899
pCtx->aDistances = sqlite3_malloc(maxCandidates * sizeof(double));
847900
pCtx->aCandidates = sqlite3_malloc(maxCandidates * sizeof(DiskAnnNode*));
848901
pCtx->nCandidates = 0;
849902
pCtx->maxCandidates = maxCandidates;
903+
pCtx->aTopDistances = sqlite3_malloc(topCandidates * sizeof(double));
904+
pCtx->aTopCandidates = sqlite3_malloc(topCandidates * sizeof(DiskAnnNode*));
905+
pCtx->nTopCandidates = 0;
906+
pCtx->maxTopCandidates = topCandidates;
850907
pCtx->visitedList = NULL;
851908
pCtx->nUnvisited = 0;
852909
pCtx->blobMode = blobMode;
853-
if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL ){
910+
if( pCtx->aDistances == NULL || pCtx->aCandidates == NULL || pCtx->aTopDistances == NULL || pCtx->aTopCandidates == NULL ){
854911
goto out_oom;
855912
}
856913
return SQLITE_OK;
@@ -861,6 +918,12 @@ static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, un
861918
if( pCtx->aCandidates != NULL ){
862919
sqlite3_free(pCtx->aCandidates);
863920
}
921+
if( pCtx->aTopDistances != NULL ){
922+
sqlite3_free(pCtx->aTopDistances);
923+
}
924+
if( pCtx->aTopCandidates != NULL ){
925+
sqlite3_free(pCtx->aTopCandidates);
926+
}
864927
return SQLITE_NOMEM_BKPT;
865928
}
866929

@@ -884,6 +947,8 @@ static void diskAnnSearchCtxDeinit(DiskAnnSearchCtx *pCtx){
884947
}
885948
sqlite3_free(pCtx->aCandidates);
886949
sqlite3_free(pCtx->aDistances);
950+
sqlite3_free(pCtx->aTopCandidates);
951+
sqlite3_free(pCtx->aTopDistances);
887952
}
888953

889954
// check if we visited this node earlier
@@ -925,7 +990,9 @@ static int diskAnnSearchCtxShouldAddCandidate(const DiskAnnIndex *pIndex, const
925990
}
926991

927992
// mark node as visited and put it in the head of visitedList
928-
static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode){
993+
static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNode, float distance){
994+
int iInsert;
995+
929996
assert( pCtx->nUnvisited > 0 );
930997
assert( pNode->visited == 0 );
931998

@@ -934,56 +1001,51 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo
9341001

9351002
pNode->pNext = pCtx->visitedList;
9361003
pCtx->visitedList = pNode;
1004+
1005+
iInsert = distanceBufferInsertIdx(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, distance);
1006+
if( iInsert < 0 ){
1007+
return;
1008+
}
1009+
bufferInsert(pCtx->aTopCandidates, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(DiskAnnNode*), &pNode, NULL);
1010+
bufferInsert(pCtx->aTopDistances, pCtx->nTopCandidates, pCtx->maxTopCandidates, iInsert, sizeof(float), &distance, NULL);
1011+
pCtx->nTopCandidates = MIN(pCtx->nTopCandidates + 1, pCtx->maxTopCandidates);
9371012
}
9381013

9391014
static int diskAnnSearchCtxHasUnvisited(const DiskAnnSearchCtx *pCtx){
9401015
return pCtx->nUnvisited > 0;
9411016
}
9421017

943-
static DiskAnnNode* diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i){
1018+
static void diskAnnSearchCtxGetCandidate(DiskAnnSearchCtx *pCtx, int i, DiskAnnNode **ppNode, float *pDistance){
9441019
assert( 0 <= i && i < pCtx->nCandidates );
945-
return pCtx->aCandidates[i];
1020+
*ppNode = pCtx->aCandidates[i];
1021+
*pDistance = pCtx->aDistances[i];
9461022
}
9471023

9481024
static void diskAnnSearchCtxDeleteCandidate(DiskAnnSearchCtx *pCtx, int iDelete){
9491025
int i;
950-
assert( 0 <= iDelete && iDelete < pCtx->nCandidates );
9511026
assert( pCtx->nUnvisited > 0 );
9521027
assert( !pCtx->aCandidates[iDelete]->visited );
9531028
assert( pCtx->aCandidates[iDelete]->pBlobSpot == NULL );
9541029

9551030
diskAnnNodeFree(pCtx->aCandidates[iDelete]);
1031+
bufferDelete(pCtx->aCandidates, pCtx->nCandidates, iDelete, sizeof(DiskAnnNode*));
1032+
bufferDelete(pCtx->aDistances, pCtx->nCandidates, iDelete, sizeof(float));
9561033

957-
for(i = iDelete + 1; i < pCtx->nCandidates; i++){
958-
pCtx->aCandidates[i - 1] = pCtx->aCandidates[i];
959-
pCtx->aDistances[i - 1] = pCtx->aDistances[i];
960-
}
9611034
pCtx->nCandidates--;
9621035
pCtx->nUnvisited--;
9631036
}
9641037

965-
static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float candidateDist){
966-
int i;
967-
assert( 0 <= iInsert && iInsert <= pCtx->nCandidates && iInsert < pCtx->maxCandidates );
968-
if( pCtx->nCandidates < pCtx->maxCandidates ){
969-
pCtx->nCandidates++;
970-
} else {
971-
DiskAnnNode *pLast = pCtx->aCandidates[pCtx->nCandidates - 1];
972-
if( !pLast->visited ){
973-
// since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node
974-
assert( pLast->pBlobSpot == NULL );
975-
pCtx->nUnvisited--;
976-
diskAnnNodeFree(pLast);
977-
}
978-
}
979-
// Shift the candidates to the right to make space for the new one.
980-
for(i = pCtx->nCandidates - 1; i > iInsert; i--){
981-
pCtx->aCandidates[i] = pCtx->aCandidates[i - 1];
982-
pCtx->aDistances[i] = pCtx->aDistances[i - 1];
1038+
static void diskAnnSearchCtxInsertCandidate(DiskAnnSearchCtx *pCtx, int iInsert, DiskAnnNode* pCandidate, float distance){
1039+
DiskAnnNode *pLast = NULL;
1040+
bufferInsert(pCtx->aCandidates, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(DiskAnnNode*), &pCandidate, &pLast);
1041+
bufferInsert(pCtx->aDistances, pCtx->nCandidates, pCtx->maxCandidates, iInsert, sizeof(float), &distance, NULL);
1042+
pCtx->nCandidates = MIN(pCtx->nCandidates + 1, pCtx->maxCandidates);
1043+
if( pLast != NULL && !pLast->visited ){
1044+
// since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node
1045+
assert( pLast->pBlobSpot == NULL );
1046+
pCtx->nUnvisited--;
1047+
diskAnnNodeFree(pLast);
9831048
}
984-
// Insert the new candidate.
985-
pCtx->aCandidates[iInsert] = pCandidate;
986-
pCtx->aDistances[iInsert] = candidateDist;
9871049
pCtx->nUnvisited++;
9881050
}
9891051

@@ -1131,7 +1193,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u
11311193
}
11321194

11331195
nodeBinVector(pIndex, start->pBlobSpot, &startVector);
1134-
startDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &startVector);
1196+
startDistance = diskAnnVectorDistance(pIndex, pCtx->pNodeQuery, &startVector);
11351197

11361198
if( pCtx->blobMode == DISKANN_BLOB_READONLY ){
11371199
assert( start->pBlobSpot != NULL );
@@ -1148,8 +1210,9 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u
11481210
Vector vCandidate;
11491211
DiskAnnNode *pCandidate;
11501212
BlobSpot *pCandidateBlob;
1213+
float distance;
11511214
int iCandidate = diskAnnSearchCtxFindClosestCandidateIdx(pCtx);
1152-
pCandidate = diskAnnSearchCtxGetCandidate(pCtx, iCandidate);
1215+
diskAnnSearchCtxGetCandidate(pCtx, iCandidate, &pCandidate, &distance);
11531216

11541217
rc = SQLITE_OK;
11551218
if( pReusableBlobSpot != NULL ){
@@ -1177,13 +1240,18 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u
11771240
goto out;
11781241
}
11791242

1180-
diskAnnSearchCtxMarkVisited(pCtx, pCandidate);
1181-
11821243
nVisited += 1;
11831244
DiskAnnTrace(("visiting candidate(%d): id=%lld\n", nVisited, pCandidate->nRowid));
11841245
nodeBinVector(pIndex, pCandidateBlob, &vCandidate);
11851246
nEdges = nodeBinEdges(pIndex, pCandidateBlob);
11861247

1248+
// if pNodeQuery != pEdgeQuery then distance from aDistances is approximate and we must recalculate it
1249+
if( pCtx->pNodeQuery != pCtx->pEdgeQuery ){
1250+
distance = diskAnnVectorDistance(pIndex, &vCandidate, pCtx->pNodeQuery);
1251+
}
1252+
1253+
diskAnnSearchCtxMarkVisited(pCtx, pCandidate, distance);
1254+
11871255
for(i = 0; i < nEdges; i++){
11881256
u64 edgeRowid;
11891257
Vector edgeVector;
@@ -1195,7 +1263,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u
11951263
continue;
11961264
}
11971265

1198-
edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pQuery, &edgeVector);
1266+
edgeDistance = diskAnnVectorDistance(pIndex, pCtx->pEdgeQuery, &edgeVector);
11991267
iInsert = diskAnnSearchCtxShouldAddCandidate(pIndex, pCtx, edgeDistance);
12001268
if( iInsert < 0 ){
12011269
continue;
@@ -1272,7 +1340,7 @@ int diskAnnSearch(
12721340
*pzErrMsg = sqlite3_mprintf("vector index(search): failed to select start node for search");
12731341
return rc;
12741342
}
1275-
rc = diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, DISKANN_BLOB_READONLY);
1343+
rc = diskAnnSearchCtxInit(&ctx, pVector, pIndex->searchL, k, DISKANN_BLOB_READONLY);
12761344
if( rc != SQLITE_OK ){
12771345
*pzErrMsg = sqlite3_mprintf("vector index(search): failed to initialize search context");
12781346
goto out;
@@ -1281,17 +1349,17 @@ int diskAnnSearch(
12811349
if( rc != SQLITE_OK ){
12821350
goto out;
12831351
}
1284-
nOutRows = MIN(k, ctx.nCandidates);
1352+
nOutRows = MIN(k, ctx.nTopCandidates);
12851353
rc = vectorOutRowsAlloc(pIndex->db, pRows, nOutRows, pKey->nKeyColumns, vectorIdxKeyRowidLike(pKey));
12861354
if( rc != SQLITE_OK ){
12871355
*pzErrMsg = sqlite3_mprintf("vector index(search): failed to allocate output rows");
12881356
goto out;
12891357
}
12901358
for(i = 0; i < nOutRows; i++){
12911359
if( pRows->aIntValues != NULL ){
1292-
rc = vectorOutRowsPut(pRows, i, 0, &ctx.aCandidates[i]->nRowid, NULL);
1360+
rc = vectorOutRowsPut(pRows, i, 0, &ctx.aTopCandidates[i]->nRowid, NULL);
12931361
}else{
1294-
rc = diskAnnGetShadowRowKeys(pIndex, ctx.aCandidates[i]->nRowid, pKey, pRows, i);
1362+
rc = diskAnnGetShadowRowKeys(pIndex, ctx.aTopCandidates[i]->nRowid, pKey, pRows, i);
12951363
}
12961364
if( rc != SQLITE_OK ){
12971365
*pzErrMsg = sqlite3_mprintf("vector index(search): failed to put result in the output row");
@@ -1327,7 +1395,7 @@ int diskAnnInsert(
13271395

13281396
DiskAnnTrace(("diskAnnInset started\n"));
13291397

1330-
rc = diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, DISKANN_BLOB_WRITABLE);
1398+
rc = diskAnnSearchCtxInit(&ctx, pVectorInRow->pVector, pIndex->insertL, 1, DISKANN_BLOB_WRITABLE);
13311399
if( rc != SQLITE_OK ){
13321400
*pzErrMsg = sqlite3_mprintf("vector index(insert): failed to initialize search context");
13331401
return rc;

0 commit comments

Comments
 (0)