关于索引构建的算法在 [[基于图的 ANNS 构建索引相关]] 中,本文将结合 HNSW 代码深入理解 HNSW 算法构建的细节
# 参数分析
# 构建参数
# 层次选择
由于 ln(unif(0..1))
的值是负数,并且乘以负号后变为正数,再乘以常数 mL
,生成的 l
值服从一种指数分布。层数越高,其出现的概率越小。这确保了大多数元素的层数较低,而极少数元素的层数会较高。
在实际系统中,层数 l
是整数值,并且受到计算机内存和性能的限制,实际中元素的层数不可能无限增长。此外,HNSW 通常设定了一个最大层数的限制,以防止极端情况下层数过高。
# 参数选择
一个简单选择最优 mL 的方法是 1/ln (M),2∙M 是 Mmax0 的一个不错选择
对于启发式方法:中低维 / 高度聚类数据效果好,均匀和非常高维的数据无效果。
# 用户参数
用户剩下的唯一有意义的构造参数是 M。M 的合理范围是从 5 到 48。模拟结果表明,对于较低的召回率和 / 或较低维度的数据,较小的 M 通常可以产生更好的结果,而较大的 M 则更适合于高召回率和 / 或高维度的数据(见图 8 进行说明,Core i5 2400 CPU)。该参数还定义了算法的内存消耗(与 M 成正比),因此应谨慎选择。
# 复杂度分析
# 时间复杂度
层内搜索的跳数 与数据集规模无关
假设节点平均度数也为常数 C, 距离计算次数 CS
对于低维数据,最大层索引
# 空间复杂度
每个元素的连接数量为零层的 和其他层的。因此,每个元素的平均内存消耗为
# 进一步
失去了分布式搜索的可能性。HNSW 结构中的搜索始终从顶层开始,因此该结构无法使用之前的相同技术进行分布式处理
# 源码分析
# 数据结构
# labeltype 和 tableint
labeltype
是外部标签的类型,可能是int
、string
等类型,用于表示数据点的唯一外部标识。tableint
是节点的内部 ID 类型,在 HNSW 中定义为unsigned int
,用于在图结构中唯一标识每个节点。-
-
labeltype
:用于表示节点的外部标签,用户通过它来引用和操作数据点。在外部系统中,用户通常使用labeltype
来查找或标识节点。
-
-
tableint
:用于表示节点的内部 ID,是在图中用于定位节点的唯一标识符。这个值是程序内部管理节点使用的,而不直接暴露给外部用户。
# 内存中节点
# 插入点
tableint addPoint(const void *data_point, labeltype label, int level) { | |
tableint cur_c = 0; | |
{ | |
// Checking if the element with the same label already exists | |
// if so, updating it *instead* of creating a new element. | |
std::unique_lock <std::mutex> lock_table(label_lookup_lock); | |
auto search = label_lookup_.find(label); | |
if (search != label_lookup_.end()) {// 找到了相同的节点 | |
tableint existingInternalId = search->second; | |
if (allow_replace_deleted_) { | |
if (isMarkedDeleted(existingInternalId)) { | |
throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled."); | |
} | |
} | |
lock_table.unlock(); | |
if (isMarkedDeleted(existingInternalId)) { | |
unmarkDeletedInternal(existingInternalId); | |
}// 取消删除标记 | |
updatePoint(data_point, existingInternalId, 1.0); | |
// 通过 updatePoint 更新该节点的数据 | |
return existingInternalId; | |
} | |
if (cur_element_count >= max_elements_) { | |
throw std::runtime_error("The number of elements exceeds the specified limit"); | |
} | |
cur_c = cur_element_count; | |
cur_element_count++; | |
label_lookup_[label] = cur_c; | |
// 函数生成一个新的 tableint 值 cur_c 作为新节点的内部 ID,并在 label_lookup_ 中建立 label 与 cur_c 的映射。 | |
} | |
std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]); | |
int curlevel = getRandomLevel(mult_); | |
if (level > 0) | |
curlevel = level; | |
element_levels_[cur_c] = curlevel; | |
// 随机生成新节点的层数(或使用指定的层数),并将其记录在 element_levels_ 中。 | |
std::unique_lock <std::mutex> templock(global); | |
int maxlevelcopy = maxlevel_; | |
if (curlevel <= maxlevelcopy) | |
templock.unlock(); | |
tableint currObj = enterpoint_node_; | |
tableint enterpoint_copy = enterpoint_node_; | |
memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); | |
// Initialisation of the data and label | |
memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); | |
memcpy(getDataByInternalId(cur_c), data_point, data_size_); | |
if (curlevel) { | |
linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); | |
if (linkLists_[cur_c] == nullptr) | |
throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); | |
memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); | |
} | |
if ((signed)currObj != -1) { | |
if (curlevel < maxlevelcopy) { | |
dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); | |
for (int level = maxlevelcopy; level > curlevel; level--) { | |
bool changed = true; | |
while (changed) {// 不断更新当前最接近的数据点 currObj,直到找不到更接近的节点为止 | |
changed = false; | |
unsigned int *data; | |
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]); | |
data = get_linklist(currObj, level); | |
int size = getListCount(data); | |
// 获取当前节点在指定层的邻居列表,size 是邻居列表的大小。 | |
tableint *datal = (tableint *) (data + 1); | |
for (int i = 0; i < size; i++) { | |
tableint cand = datal[i]; | |
if (cand < 0 || cand > max_elements_) | |
throw std::runtime_error("cand error"); | |
dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);// 距离计算 | |
if (d < curdist) { | |
curdist = d; | |
currObj = cand; | |
changed = true; | |
} | |
} | |
} | |
} | |
}//todo | |
bool epDeleted = isMarkedDeleted(enterpoint_copy); | |
for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { | |
if (level > maxlevelcopy || level < 0) // possible? | |
throw std::runtime_error("Level error"); | |
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer( | |
currObj, data_point, level); | |
//searchBaseLayer 函数会在指定的 level 层进行搜索,寻找与当前节点 data_point 最近的候选邻居节点,返回一个优先队列 top_candidates,队列中的元素按照距离排序,距离越近的节点排在前面。 | |
// 和论文的算法不一样,这里的 ep:currObj 仍然只有一个 | |
if (epDeleted) { | |
top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); | |
if (top_candidates.size() > ef_construction_) | |
top_candidates.pop(); | |
}// 如果入口点 enterpoint_copy 被标记为已删除,则将其作为候选节点之一加入到 top_candidates 中,因为入口节点很重要? | |
currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); | |
// 在 level 层将新插入节点 cur_c 与 top_candidates 中的候选节点相互连接(双向连接)。 | |
} | |
} else { | |
// Do nothing for the first element | |
enterpoint_node_ = 0; | |
maxlevel_ = curlevel; | |
} | |
// Releasing lock for the maximum level | |
if (curlevel > maxlevelcopy) { | |
enterpoint_node_ = cur_c; | |
maxlevel_ = curlevel; | |
} | |
return cur_c; | |
} |
# 选择邻居以及更新邻接表
tableint mutuallyConnectNewElement( | |
const void *data_point, | |
tableint cur_c, | |
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates, | |
int level, | |
bool isUpdate) { | |
size_t Mcurmax = level ? maxM_ : maxM0_;//Mcurmax: 决定了当前层的最大邻居数量。如果在第 0 层,邻居数上限为 maxM0_,否则为 maxM_。 | |
getNeighborsByHeuristic2(top_candidates, M_);// 使用一种启发式方法从 top_candidates 中选择 M_个最优的邻居节点 | |
if (top_candidates.size() > M_) | |
throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); | |
std::vector<tableint> selectedNeighbors; | |
selectedNeighbors.reserve(M_); | |
while (top_candidates.size() > 0) { | |
selectedNeighbors.push_back(top_candidates.top().second); | |
top_candidates.pop(); | |
} | |
tableint next_closest_entry_point = selectedNeighbors.back(); | |
// 更新 cur 的邻接表 | |
{ | |
// lock only during the update | |
// because during the addition the lock for cur_c is already acquired | |
std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock); | |
if (isUpdate) { | |
lock.lock(); | |
} | |
linklistsizeint *ll_cur; | |
if (level == 0) | |
ll_cur = get_linklist0(cur_c); | |
else | |
ll_cur = get_linklist(cur_c, level); | |
if (*ll_cur && !isUpdate) { | |
throw std::runtime_error("The newly inserted element should have blank link list"); | |
} | |
setListCount(ll_cur, selectedNeighbors.size()); | |
tableint *data = (tableint *) (ll_cur + 1); | |
for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { | |
if (data[idx] && !isUpdate) | |
throw std::runtime_error("Possible memory corruption"); | |
if (level > element_levels_[selectedNeighbors[idx]]) | |
throw std::runtime_error("Trying to make a link on a non-existent level"); | |
data[idx] = selectedNeighbors[idx]; | |
}// 将 selectedNeighbors 中的节点 ID 填充到当前节点的邻居列表中 | |
} | |
// 更新邻居节点的邻接表 | |
for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { | |
std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]); | |
linklistsizeint *ll_other; | |
if (level == 0) | |
ll_other = get_linklist0(selectedNeighbors[idx]); | |
else | |
ll_other = get_linklist(selectedNeighbors[idx], level); | |
size_t sz_link_list_other = getListCount(ll_other); | |
if (sz_link_list_other > Mcurmax) | |
throw std::runtime_error("Bad value of sz_link_list_other"); | |
if (selectedNeighbors[idx] == cur_c) | |
throw std::runtime_error("Trying to connect an element to itself"); | |
if (level > element_levels_[selectedNeighbors[idx]]) | |
throw std::runtime_error("Trying to make a link on a non-existent level"); | |
tableint *data = (tableint *) (ll_other + 1); | |
bool is_cur_c_present = false; | |
if (isUpdate) { | |
for (size_t j = 0; j < sz_link_list_other; j++) { | |
if (data[j] == cur_c) { | |
is_cur_c_present = true; | |
break; | |
} | |
} | |
} | |
// If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. | |
if (!is_cur_c_present) { | |
if (sz_link_list_other < Mcurmax) {// 没有超过最大邻居数量 | |
data[sz_link_list_other] = cur_c; | |
setListCount(ll_other, sz_link_list_other + 1); | |
} else { | |
// finding the "weakest" element to replace it with the new one | |
dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), | |
dist_func_param_); | |
// Heuristic: | |
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates; | |
candidates.emplace(d_max, cur_c); | |
for (size_t j = 0; j < sz_link_list_other; j++) { | |
candidates.emplace( | |
fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), | |
dist_func_param_), data[j]); | |
} | |
getNeighborsByHeuristic2(candidates, Mcurmax); | |
// 将 cur 和已有节点 与这个 neighbor 的距离计算出来再次启发式筛选 | |
int indx = 0; | |
while (candidates.size() > 0) { | |
data[indx] = candidates.top().second; | |
candidates.pop(); | |
indx++; | |
}// 重新填入邻接表 | |
setListCount(ll_other, indx); | |
// Nearest K: | |
/*int indx = -1; | |
for (int j = 0; j < sz_link_list_other; j++) { | |
dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); | |
if (d > d_max) { | |
indx = j; | |
d_max = d; | |
} | |
} | |
if (indx >= 0) { | |
data[indx] = cur_c; | |
} */ | |
} | |
} | |
} | |
return next_closest_entry_point; | |
} |
# 启发式选择邻居
void getNeighborsByHeuristic2( | |
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates, | |
const size_t M) { | |
if (top_candidates.size() < M) { | |
return; | |
} | |
std::priority_queue<std::pair<dist_t, tableint>> queue_closest; | |
std::vector<std::pair<dist_t, tableint>> return_list; | |
while (top_candidates.size() > 0) { | |
queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); | |
top_candidates.pop(); | |
}// 新的优先队列保持距离从远到近的顺序 | |
while (queue_closest.size()) { | |
if (return_list.size() >= M) | |
break; | |
std::pair<dist_t, tableint> curent_pair = queue_closest.top(); | |
dist_t dist_to_query = -curent_pair.first; | |
queue_closest.pop(); | |
bool good = true; | |
for (std::pair<dist_t, tableint> second_pair : return_list) { | |
dist_t curdist = | |
fstdistfunc_(getDataByInternalId(second_pair.second), | |
getDataByInternalId(curent_pair.second), | |
dist_func_param_);// 距离计算 | |
if (curdist < dist_to_query) { | |
good = false; | |
break; | |
} | |
} | |
if (good) { | |
return_list.push_back(curent_pair); | |
} | |
}// 这里没有论文中补充至 M 个的步骤,也就是说会少于 M 个 | |
for (std::pair<dist_t, tableint> curent_pair : return_list) { | |
top_candidates.emplace(-curent_pair.first, curent_pair.second); | |
} | |
} |
# 参考
HNSW 算法原理与源码解读_hnsw 有多少层 - CSDN 博客
nmslib/hnswlib: Header-only C++/python library for fast approximate nearest neighbors (github.com)