@@ -98,14 +98,19 @@ struct DiskAnnNode {
9898 * so caller which puts nodes in the context can forget about resource management (context will take care of this)
9999*/
100100struct DiskAnnSearchCtx {
101- const Vector * pQuery ; /* initial query vector; user query for SELECT and row vector for INSERT */
102- DiskAnnNode * * aCandidates ; /* array of candidates ordered by distance to the query (ascending) */
103- double * aDistances ; /* array of distances to the query vector */
104- unsigned int nCandidates ; /* current size of aCandidates/aDistances arrays */
105- unsigned int maxCandidates ; /* max size of aCandidates/aDistances arrays */
106- DiskAnnNode * visitedList ; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */
107- unsigned int nUnvisited ; /* amount of unvisited candidates in the aCadidates array */
108- int blobMode ; /* DISKANN_BLOB_READONLY if we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */
101+ const Vector * pNodeQuery ; /* initial query vector; user query for SELECT and row vector for INSERT */
102+ const Vector * pEdgeQuery ; /* initial query vector; user query for SELECT and row vector for INSERT */
103+ DiskAnnNode * * aCandidates ; /* array of candidates ordered by distance to the query (ascending) */
104+ float * aDistances ; /* array of distances to the query vector */
105+ unsigned int nCandidates ; /* current size of aCandidates/aDistances arrays */
106+ unsigned int maxCandidates ; /* max size of aCandidates/aDistances arrays */
107+ DiskAnnNode * * aTopCandidates ; /* top candidates with exact distance calculated */
108+ float * aTopDistances ; /* top candidates exact distances */
109+ int nTopCandidates ; /* current size of aTopCandidates/aTopDistances arrays */
110+ int maxTopCandidates ; /* max size of aTopCandidates/aTopDistances arrays */
111+ DiskAnnNode * visitedList ; /* list of all visited candidates (so, candidates from aCandidates array either got replaced or moved to the visited list) */
112+ unsigned int nUnvisited ; /* amount of unvisited candidates in the aCadidates array */
113+ int blobMode ; /* DISKANN_BLOB_READONLY if we wont modify node blobs; DISKANN_BLOB_WRITABLE - otherwise */
109114};
110115
111116/**************************************************************************
@@ -805,6 +810,53 @@ static int diskAnnDeleteShadowRow(const DiskAnnIndex *pIndex, i64 nRowid){
805810 return rc ;
806811}
807812
813+ /**************************************************************************
814+ ** Generic utilities
815+ **************************************************************************/
816+
/*
** Locate the slot where `distance` belongs in the ascending array
** aDistances[0..nSize).
**
** Returns the index of the first element strictly greater than `distance`
** (so equal distances keep their existing order). If no such element exists,
** returns nSize when the buffer still has room (nSize < nMaxSize) and -1
** when it is full, meaning the value must not be inserted.
*/
int distanceBufferInsertIdx(const float *aDistances, int nSize, int nMaxSize, float distance){
  int iPos = 0;
#ifdef SQLITE_DEBUG
  {
    /* debug builds verify that the buffer really is sorted ascending */
    int k;
    for(k = 1; k < nSize; k++){
      assert( aDistances[k - 1] <= aDistances[k] );
    }
  }
#endif
  while( iPos < nSize && !(distance < aDistances[iPos]) ){
    iPos++;
  }
  if( iPos < nSize ){
    return iPos;
  }
  if( nSize < nMaxSize ){
    return nSize;
  }
  return -1;
}
831+
/*
** Insert one item into a densely packed array of fixed-size items.
**
**   aBuffer   - start of the array; must have room for nMaxSize items
**   nSize     - number of items currently stored (nSize <= nMaxSize)
**   nMaxSize  - capacity of the array in items (> 0)
**   iInsert   - position of the new item (0 <= iInsert <= nSize, iInsert < nMaxSize)
**   nItemSize - size of a single item in bytes (> 0)
**   pItem     - pointer to the item to copy into position iInsert
**   pLast     - optional; when the array is full, the item that falls off the
**               end is copied here before being overwritten. Left untouched
**               when the array is not full or when pLast is NULL.
**
** The function does not report the new size: nSize is passed by value, so
** the caller has to update its own counter.
*/
void bufferInsert(void *aBuffer, int nSize, int nMaxSize, int iInsert, int nItemSize, const void *pItem, void *pLast){
  /* arithmetic on void* is a compiler extension - address the buffer in bytes instead */
  unsigned char *aBytes = (unsigned char*)aBuffer;
  size_t szItem = (size_t)nItemSize;
  size_t nItemsToMove;

  assert( nMaxSize > 0 && nItemSize > 0 );
  assert( nSize <= nMaxSize );
  assert( 0 <= iInsert && iInsert <= nSize && iInsert < nMaxSize );

  if( nSize == nMaxSize ){
    /* buffer is full: the last item is evicted to make room for the new one */
    if( pLast != NULL ){
      memcpy(pLast, aBytes + (size_t)(nSize - 1) * szItem, szItem);
    }
    nSize--;
  }
  nItemsToMove = (size_t)(nSize - iInsert);
  /* regions overlap, so the tail must be shifted with memmove */
  memmove(aBytes + ((size_t)iInsert + 1) * szItem, aBytes + (size_t)iInsert * szItem, nItemsToMove * szItem);
  memcpy(aBytes + (size_t)iInsert * szItem, pItem, szItem);
}
849+
/*
** Remove the item at position iDelete from a densely packed array of
** fixed-size items by shifting every following item one slot to the left.
**
**   aBuffer   - start of the array
**   nSize     - number of items currently stored
**   iDelete   - position of the item to remove (0 <= iDelete < nSize)
**   nItemSize - size of a single item in bytes (> 0)
**
** The slot that becomes free at the end keeps its old bytes, and the caller
** has to decrement its own size counter (nSize is passed by value).
*/
void bufferDelete(void *aBuffer, int nSize, int iDelete, int nItemSize){
  /* arithmetic on void* is a compiler extension - address the buffer in bytes instead */
  unsigned char *aBytes = (unsigned char*)aBuffer;
  size_t szItem = (size_t)nItemSize;
  size_t nItemsToMove;

  assert( nItemSize > 0 );
  assert( 0 <= iDelete && iDelete < nSize );

  nItemsToMove = (size_t)(nSize - iDelete - 1);
  /* regions overlap, so the tail must be shifted with memmove */
  memmove(aBytes + (size_t)iDelete * szItem, aBytes + ((size_t)iDelete + 1) * szItem, nItemsToMove * szItem);
}
859+
808860/**************************************************************************
809861** DiskANN internals
810862**************************************************************************/
@@ -841,16 +893,21 @@ static void diskAnnNodeFree(DiskAnnNode *pNode){
841893 sqlite3_free (pNode );
842894}
843895
844- static int diskAnnSearchCtxInit (DiskAnnSearchCtx * pCtx , const Vector * pQuery , unsigned int maxCandidates , int blobMode ){
845- pCtx -> pQuery = pQuery ;
896+ static int diskAnnSearchCtxInit (DiskAnnSearchCtx * pCtx , const Vector * pQuery , int maxCandidates , int topCandidates , int blobMode ){
897+ pCtx -> pNodeQuery = pQuery ;
898+ pCtx -> pEdgeQuery = pQuery ;
846899 pCtx -> aDistances = sqlite3_malloc (maxCandidates * sizeof (double ));
847900 pCtx -> aCandidates = sqlite3_malloc (maxCandidates * sizeof (DiskAnnNode * ));
848901 pCtx -> nCandidates = 0 ;
849902 pCtx -> maxCandidates = maxCandidates ;
903+ pCtx -> aTopDistances = sqlite3_malloc (topCandidates * sizeof (double ));
904+ pCtx -> aTopCandidates = sqlite3_malloc (topCandidates * sizeof (DiskAnnNode * ));
905+ pCtx -> nTopCandidates = 0 ;
906+ pCtx -> maxTopCandidates = topCandidates ;
850907 pCtx -> visitedList = NULL ;
851908 pCtx -> nUnvisited = 0 ;
852909 pCtx -> blobMode = blobMode ;
853- if ( pCtx -> aDistances == NULL || pCtx -> aCandidates == NULL ){
910+ if ( pCtx -> aDistances == NULL || pCtx -> aCandidates == NULL || pCtx -> aTopDistances == NULL || pCtx -> aTopCandidates == NULL ){
854911 goto out_oom ;
855912 }
856913 return SQLITE_OK ;
@@ -861,6 +918,12 @@ static int diskAnnSearchCtxInit(DiskAnnSearchCtx *pCtx, const Vector* pQuery, un
861918 if ( pCtx -> aCandidates != NULL ){
862919 sqlite3_free (pCtx -> aCandidates );
863920 }
921+ if ( pCtx -> aTopDistances != NULL ){
922+ sqlite3_free (pCtx -> aTopDistances );
923+ }
924+ if ( pCtx -> aTopCandidates != NULL ){
925+ sqlite3_free (pCtx -> aTopCandidates );
926+ }
864927 return SQLITE_NOMEM_BKPT ;
865928}
866929
@@ -884,6 +947,8 @@ static void diskAnnSearchCtxDeinit(DiskAnnSearchCtx *pCtx){
884947 }
885948 sqlite3_free (pCtx -> aCandidates );
886949 sqlite3_free (pCtx -> aDistances );
950+ sqlite3_free (pCtx -> aTopCandidates );
951+ sqlite3_free (pCtx -> aTopDistances );
887952}
888953
889954// check if we visited this node earlier
@@ -925,7 +990,9 @@ static int diskAnnSearchCtxShouldAddCandidate(const DiskAnnIndex *pIndex, const
925990}
926991
927992// mark node as visited and put it in the head of visitedList
928- static void diskAnnSearchCtxMarkVisited (DiskAnnSearchCtx * pCtx , DiskAnnNode * pNode ){
993+ static void diskAnnSearchCtxMarkVisited (DiskAnnSearchCtx * pCtx , DiskAnnNode * pNode , float distance ){
994+ int iInsert ;
995+
929996 assert ( pCtx -> nUnvisited > 0 );
930997 assert ( pNode -> visited == 0 );
931998
@@ -934,56 +1001,51 @@ static void diskAnnSearchCtxMarkVisited(DiskAnnSearchCtx *pCtx, DiskAnnNode *pNo
9341001
9351002 pNode -> pNext = pCtx -> visitedList ;
9361003 pCtx -> visitedList = pNode ;
1004+
1005+ iInsert = distanceBufferInsertIdx (pCtx -> aTopDistances , pCtx -> nTopCandidates , pCtx -> maxTopCandidates , distance );
1006+ if ( iInsert < 0 ){
1007+ return ;
1008+ }
1009+ bufferInsert (pCtx -> aTopCandidates , pCtx -> nTopCandidates , pCtx -> maxTopCandidates , iInsert , sizeof (DiskAnnNode * ), & pNode , NULL );
1010+ bufferInsert (pCtx -> aTopDistances , pCtx -> nTopCandidates , pCtx -> maxTopCandidates , iInsert , sizeof (float ), & distance , NULL );
1011+ pCtx -> nTopCandidates = MIN (pCtx -> nTopCandidates + 1 , pCtx -> maxTopCandidates );
9371012}
9381013
9391014static int diskAnnSearchCtxHasUnvisited (const DiskAnnSearchCtx * pCtx ){
9401015 return pCtx -> nUnvisited > 0 ;
9411016}
9421017
943- static DiskAnnNode * diskAnnSearchCtxGetCandidate (DiskAnnSearchCtx * pCtx , int i ){
1018+ static void diskAnnSearchCtxGetCandidate (DiskAnnSearchCtx * pCtx , int i , DiskAnnNode * * ppNode , float * pDistance ){
9441019 assert ( 0 <= i && i < pCtx -> nCandidates );
945- return pCtx -> aCandidates [i ];
1020+ * ppNode = pCtx -> aCandidates [i ];
1021+ * pDistance = pCtx -> aDistances [i ];
9461022}
9471023
9481024static void diskAnnSearchCtxDeleteCandidate (DiskAnnSearchCtx * pCtx , int iDelete ){
9491025 int i ;
950- assert ( 0 <= iDelete && iDelete < pCtx -> nCandidates );
9511026 assert ( pCtx -> nUnvisited > 0 );
9521027 assert ( !pCtx -> aCandidates [iDelete ]-> visited );
9531028 assert ( pCtx -> aCandidates [iDelete ]-> pBlobSpot == NULL );
9541029
9551030 diskAnnNodeFree (pCtx -> aCandidates [iDelete ]);
1031+ bufferDelete (pCtx -> aCandidates , pCtx -> nCandidates , iDelete , sizeof (DiskAnnNode * ));
1032+ bufferDelete (pCtx -> aDistances , pCtx -> nCandidates , iDelete , sizeof (float ));
9561033
957- for (i = iDelete + 1 ; i < pCtx -> nCandidates ; i ++ ){
958- pCtx -> aCandidates [i - 1 ] = pCtx -> aCandidates [i ];
959- pCtx -> aDistances [i - 1 ] = pCtx -> aDistances [i ];
960- }
9611034 pCtx -> nCandidates -- ;
9621035 pCtx -> nUnvisited -- ;
9631036}
9641037
965- static void diskAnnSearchCtxInsertCandidate (DiskAnnSearchCtx * pCtx , int iInsert , DiskAnnNode * pCandidate , float candidateDist ){
966- int i ;
967- assert ( 0 <= iInsert && iInsert <= pCtx -> nCandidates && iInsert < pCtx -> maxCandidates );
968- if ( pCtx -> nCandidates < pCtx -> maxCandidates ){
969- pCtx -> nCandidates ++ ;
970- } else {
971- DiskAnnNode * pLast = pCtx -> aCandidates [pCtx -> nCandidates - 1 ];
972- if ( !pLast -> visited ){
973- // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node
974- assert ( pLast -> pBlobSpot == NULL );
975- pCtx -> nUnvisited -- ;
976- diskAnnNodeFree (pLast );
977- }
978- }
979- // Shift the candidates to the right to make space for the new one.
980- for (i = pCtx -> nCandidates - 1 ; i > iInsert ; i -- ){
981- pCtx -> aCandidates [i ] = pCtx -> aCandidates [i - 1 ];
982- pCtx -> aDistances [i ] = pCtx -> aDistances [i - 1 ];
1038+ static void diskAnnSearchCtxInsertCandidate (DiskAnnSearchCtx * pCtx , int iInsert , DiskAnnNode * pCandidate , float distance ){
1039+ DiskAnnNode * pLast = NULL ;
1040+ bufferInsert (pCtx -> aCandidates , pCtx -> nCandidates , pCtx -> maxCandidates , iInsert , sizeof (DiskAnnNode * ), & pCandidate , & pLast );
1041+ bufferInsert (pCtx -> aDistances , pCtx -> nCandidates , pCtx -> maxCandidates , iInsert , sizeof (float ), & distance , NULL );
1042+ pCtx -> nCandidates = MIN (pCtx -> nCandidates + 1 , pCtx -> maxCandidates );
1043+ if ( pLast != NULL && !pLast -> visited ){
1044+ // since pLast is not visited it should have uninitialized pBlobSpot - so it's safe to completely free the node
1045+ assert ( pLast -> pBlobSpot == NULL );
1046+ pCtx -> nUnvisited -- ;
1047+ diskAnnNodeFree (pLast );
9831048 }
984- // Insert the new candidate.
985- pCtx -> aCandidates [iInsert ] = pCandidate ;
986- pCtx -> aDistances [iInsert ] = candidateDist ;
9871049 pCtx -> nUnvisited ++ ;
9881050}
9891051
@@ -1131,7 +1193,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u
11311193 }
11321194
11331195 nodeBinVector (pIndex , start -> pBlobSpot , & startVector );
1134- startDistance = diskAnnVectorDistance (pIndex , pCtx -> pQuery , & startVector );
1196+ startDistance = diskAnnVectorDistance (pIndex , pCtx -> pNodeQuery , & startVector );
11351197
11361198 if ( pCtx -> blobMode == DISKANN_BLOB_READONLY ){
11371199 assert ( start -> pBlobSpot != NULL );
@@ -1148,8 +1210,9 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u
11481210 Vector vCandidate ;
11491211 DiskAnnNode * pCandidate ;
11501212 BlobSpot * pCandidateBlob ;
1213+ float distance ;
11511214 int iCandidate = diskAnnSearchCtxFindClosestCandidateIdx (pCtx );
1152- pCandidate = diskAnnSearchCtxGetCandidate (pCtx , iCandidate );
1215+ diskAnnSearchCtxGetCandidate (pCtx , iCandidate , & pCandidate , & distance );
11531216
11541217 rc = SQLITE_OK ;
11551218 if ( pReusableBlobSpot != NULL ){
@@ -1177,13 +1240,18 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u
11771240 goto out ;
11781241 }
11791242
1180- diskAnnSearchCtxMarkVisited (pCtx , pCandidate );
1181-
11821243 nVisited += 1 ;
11831244 DiskAnnTrace (("visiting candidate(%d): id=%lld\n" , nVisited , pCandidate -> nRowid ));
11841245 nodeBinVector (pIndex , pCandidateBlob , & vCandidate );
11851246 nEdges = nodeBinEdges (pIndex , pCandidateBlob );
11861247
1248+ // if pNodeQuery != pEdgeQuery then distance from aDistances is approximate and we must recalculate it
1249+ if ( pCtx -> pNodeQuery != pCtx -> pEdgeQuery ){
1250+ distance = diskAnnVectorDistance (pIndex , & vCandidate , pCtx -> pNodeQuery );
1251+ }
1252+
1253+ diskAnnSearchCtxMarkVisited (pCtx , pCandidate , distance );
1254+
11871255 for (i = 0 ; i < nEdges ; i ++ ){
11881256 u64 edgeRowid ;
11891257 Vector edgeVector ;
@@ -1195,7 +1263,7 @@ static int diskAnnSearchInternal(DiskAnnIndex *pIndex, DiskAnnSearchCtx *pCtx, u
11951263 continue ;
11961264 }
11971265
1198- edgeDistance = diskAnnVectorDistance (pIndex , pCtx -> pQuery , & edgeVector );
1266+ edgeDistance = diskAnnVectorDistance (pIndex , pCtx -> pEdgeQuery , & edgeVector );
11991267 iInsert = diskAnnSearchCtxShouldAddCandidate (pIndex , pCtx , edgeDistance );
12001268 if ( iInsert < 0 ){
12011269 continue ;
@@ -1272,7 +1340,7 @@ int diskAnnSearch(
12721340 * pzErrMsg = sqlite3_mprintf ("vector index(search): failed to select start node for search" );
12731341 return rc ;
12741342 }
1275- rc = diskAnnSearchCtxInit (& ctx , pVector , pIndex -> searchL , DISKANN_BLOB_READONLY );
1343+ rc = diskAnnSearchCtxInit (& ctx , pVector , pIndex -> searchL , k , DISKANN_BLOB_READONLY );
12761344 if ( rc != SQLITE_OK ){
12771345 * pzErrMsg = sqlite3_mprintf ("vector index(search): failed to initialize search context" );
12781346 goto out ;
@@ -1281,17 +1349,17 @@ int diskAnnSearch(
12811349 if ( rc != SQLITE_OK ){
12821350 goto out ;
12831351 }
1284- nOutRows = MIN (k , ctx .nCandidates );
1352+ nOutRows = MIN (k , ctx .nTopCandidates );
12851353 rc = vectorOutRowsAlloc (pIndex -> db , pRows , nOutRows , pKey -> nKeyColumns , vectorIdxKeyRowidLike (pKey ));
12861354 if ( rc != SQLITE_OK ){
12871355 * pzErrMsg = sqlite3_mprintf ("vector index(search): failed to allocate output rows" );
12881356 goto out ;
12891357 }
12901358 for (i = 0 ; i < nOutRows ; i ++ ){
12911359 if ( pRows -> aIntValues != NULL ){
1292- rc = vectorOutRowsPut (pRows , i , 0 , & ctx .aCandidates [i ]-> nRowid , NULL );
1360+ rc = vectorOutRowsPut (pRows , i , 0 , & ctx .aTopCandidates [i ]-> nRowid , NULL );
12931361 }else {
1294- rc = diskAnnGetShadowRowKeys (pIndex , ctx .aCandidates [i ]-> nRowid , pKey , pRows , i );
1362+ rc = diskAnnGetShadowRowKeys (pIndex , ctx .aTopCandidates [i ]-> nRowid , pKey , pRows , i );
12951363 }
12961364 if ( rc != SQLITE_OK ){
12971365 * pzErrMsg = sqlite3_mprintf ("vector index(search): failed to put result in the output row" );
@@ -1327,7 +1395,7 @@ int diskAnnInsert(
13271395
13281396 DiskAnnTrace (("diskAnnInset started\n" ));
13291397
1330- rc = diskAnnSearchCtxInit (& ctx , pVectorInRow -> pVector , pIndex -> insertL , DISKANN_BLOB_WRITABLE );
1398+ rc = diskAnnSearchCtxInit (& ctx , pVectorInRow -> pVector , pIndex -> insertL , 1 , DISKANN_BLOB_WRITABLE );
13311399 if ( rc != SQLITE_OK ){
13321400 * pzErrMsg = sqlite3_mprintf ("vector index(insert): failed to initialize search context" );
13331401 return rc ;
0 commit comments