关于索引构建的算法在 [[基于图的 ANNS 构建索引相关]] 中,本文将结合 HNSW 代码深入理解 HNSW 算法构建的细节

# 参数分析

# 构建参数

# 层次选择

lln(unif(0..1))mLl←⌊−ln(unif(0..1))⋅m_L⌋
由于 ln(unif(0..1)) 的值是负数,并且乘以负号后变为正数,再乘以常数 mL ,生成的 l 值服从一种指数分布。层数越高,其出现的概率越小。这确保了大多数元素的层数较低,而极少数元素的层数会较高。
在实际系统中,层数 l 是整数值,并且受到计算机内存和性能的限制,实际中元素的层数不可能无限增长。此外,HNSW 通常设定了一个最大层数的限制,以防止极端情况下层数过高。

# 参数选择

一个简单选择最优 mL 的方法是 1/ln (M),2∙M 是 Mmax0 的一个不错选择
image.png
对于启发式方法:中低维 / 高度聚类数据效果好,均匀和非常高维的数据无效果。

# 用户参数

用户剩下的唯一有意义的构造参数是 M。M 的合理范围是从 5 到 48。模拟结果表明,对于较低的召回率和 / 或较低维度的数据,较小的 M 通常可以产生更好的结果,而较大的 M 则更适合于高召回率和 / 或高维度的数据(见图 8 进行说明,Core i5 2400 CPU)。该参数还定义了算法的内存消耗(与 M 成正比),因此应谨慎选择。
image.png

# 复杂度分析

# 时间复杂度

层内搜索的跳数S=1/(1exp(mL))S=1/(1-exp(-m_L)) 与数据集规模无关
假设节点平均度数也为常数 C, 距离计算次数 CS
对于低维数据,最大层索引O(log(N))O(log(N))

# 空间复杂度

每个元素的连接数量为零层的Mmax0M_{max0} 和其他层的MmaxM_{max}。因此,每个元素的平均内存消耗为(Mmax0+mLMmax)bytes_per_link(M_{max0}+m_L ∙M_{max})∙bytes\_per\_link

# 进一步

失去了分布式搜索的可能性。HNSW 结构中的搜索始终从顶层开始,因此该结构无法使用之前的相同技术进行分布式处理

# 源码分析

# 数据结构

# labeltype 和 tableint

  • labeltype 是外部标签的类型,可能是 intstring 等类型,用于表示数据点的唯一外部标识。
  • tableint 是节点的内部 ID 类型,在 HNSW 中定义为 unsigned int ,用于在图结构中唯一标识每个节点。
    • labeltype :用于表示节点的外部标签,用户通过它来引用和操作数据点。在外部系统中,用户通常使用 labeltype 来查找或标识节点。
  • tableint :用于表示节点的内部 ID,是在图中用于定位节点的唯一标识符。这个值是程序内部管理节点使用的,而不直接暴露给外部用户。

# 内存中节点

917111494201114222.jpg

# 插入点

tableint addPoint(const void *data_point, labeltype label, int level) {
        tableint cur_c = 0;
        {
            // Checking if the element with the same label already exists
            // if so, updating it *instead* of creating a new element.
            std::unique_lock <std::mutex> lock_table(label_lookup_lock);
            auto search = label_lookup_.find(label);
            if (search != label_lookup_.end()) {// 找到了相同的节点
                tableint existingInternalId = search->second;
                if (allow_replace_deleted_) {
                    if (isMarkedDeleted(existingInternalId)) {
                        throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled.");
                    }
                }
                lock_table.unlock();
                if (isMarkedDeleted(existingInternalId)) {
                    unmarkDeletedInternal(existingInternalId);
                }// 取消删除标记
                updatePoint(data_point, existingInternalId, 1.0);
                // 通过 updatePoint 更新该节点的数据
                return existingInternalId;
            }
            if (cur_element_count >= max_elements_) {
                throw std::runtime_error("The number of elements exceeds the specified limit");
            }
            cur_c = cur_element_count;
            cur_element_count++;
            label_lookup_[label] = cur_c;
            // 函数生成一个新的 tableint 值 cur_c 作为新节点的内部 ID,并在 label_lookup_ 中建立 label 与 cur_c 的映射。
        }
        std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
        int curlevel = getRandomLevel(mult_);
        if (level > 0)
            curlevel = level;
        element_levels_[cur_c] = curlevel;
        // 随机生成新节点的层数(或使用指定的层数),并将其记录在 element_levels_ 中。
        std::unique_lock <std::mutex> templock(global);
        int maxlevelcopy = maxlevel_;
        if (curlevel <= maxlevelcopy)
            templock.unlock();
        tableint currObj = enterpoint_node_;
        tableint enterpoint_copy = enterpoint_node_;
        memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
        // Initialisation of the data and label
        memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
        memcpy(getDataByInternalId(cur_c), data_point, data_size_);
        if (curlevel) {
            linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
            if (linkLists_[cur_c] == nullptr)
                throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
            memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
        }
        if ((signed)currObj != -1) {
            if (curlevel < maxlevelcopy) {
                dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
                for (int level = maxlevelcopy; level > curlevel; level--) {
                    bool changed = true;
                    while (changed) {// 不断更新当前最接近的数据点 currObj,直到找不到更接近的节点为止
                        changed = false;
                        unsigned int *data;
                        std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
                        data = get_linklist(currObj, level);
                        int size = getListCount(data);
                        // 获取当前节点在指定层的邻居列表,size 是邻居列表的大小。
                        tableint *datal = (tableint *) (data + 1);
                        for (int i = 0; i < size; i++) {
                            tableint cand = datal[i];
                            if (cand < 0 || cand > max_elements_)
                                throw std::runtime_error("cand error");
                            dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);// 距离计算
                            if (d < curdist) {
                                curdist = d;
                                currObj = cand;
                                changed = true;
                            }
                        }
                    }
                }
            }//todo
            bool epDeleted = isMarkedDeleted(enterpoint_copy);
            for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
                if (level > maxlevelcopy || level < 0)  // possible?
                    throw std::runtime_error("Level error");
                std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
                        currObj, data_point, level);
                        //searchBaseLayer 函数会在指定的 level 层进行搜索,寻找与当前节点 data_point 最近的候选邻居节点,返回一个优先队列 top_candidates,队列中的元素按照距离排序,距离越近的节点排在前面。
                        // 和论文的算法不一样,这里的 ep:currObj 仍然只有一个
                if (epDeleted) {
                    top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
                    if (top_candidates.size() > ef_construction_)
                        top_candidates.pop();
                }// 如果入口点 enterpoint_copy 被标记为已删除,则将其作为候选节点之一加入到 top_candidates 中,因为入口节点很重要?
                currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
                // 在 level 层将新插入节点 cur_c 与 top_candidates 中的候选节点相互连接(双向连接)。
            }
        } else {
            // Do nothing for the first element
            enterpoint_node_ = 0;
            maxlevel_ = curlevel;
        }
        // Releasing lock for the maximum level
        if (curlevel > maxlevelcopy) {
            enterpoint_node_ = cur_c;
            maxlevel_ = curlevel;
        }
        return cur_c;
    }

# 选择邻居以及更新邻接表

p
tableint  mutuallyConnectNewElement(
        const void *data_point,
        tableint cur_c,
        std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
        int level,
        bool isUpdate) {
        size_t Mcurmax = level ? maxM_ : maxM0_;//Mcurmax: 决定了当前层的最大邻居数量。如果在第 0 层,邻居数上限为 maxM0_,否则为 maxM_。
        getNeighborsByHeuristic2(top_candidates, M_);// 使用一种启发式方法从 top_candidates 中选择 M_个最优的邻居节点
        if (top_candidates.size() > M_)
            throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
        std::vector<tableint> selectedNeighbors;
        selectedNeighbors.reserve(M_);
        while (top_candidates.size() > 0) {
            selectedNeighbors.push_back(top_candidates.top().second);
            top_candidates.pop();
        }
        tableint next_closest_entry_point = selectedNeighbors.back();
        // 更新 cur 的邻接表
        {
            // lock only during the update
            // because during the addition the lock for cur_c is already acquired
            std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock);
            if (isUpdate) {
                lock.lock();
            }
            linklistsizeint *ll_cur;
            if (level == 0)
                ll_cur = get_linklist0(cur_c);
            else
                ll_cur = get_linklist(cur_c, level);
            if (*ll_cur && !isUpdate) {
                throw std::runtime_error("The newly inserted element should have blank link list");
            }
            setListCount(ll_cur, selectedNeighbors.size());
            tableint *data = (tableint *) (ll_cur + 1);
            for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
                if (data[idx] && !isUpdate)
                    throw std::runtime_error("Possible memory corruption");
                if (level > element_levels_[selectedNeighbors[idx]])
                    throw std::runtime_error("Trying to make a link on a non-existent level");
                data[idx] = selectedNeighbors[idx];
            }// 将 selectedNeighbors 中的节点 ID 填充到当前节点的邻居列表中
        }
        // 更新邻居节点的邻接表
        for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
            std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
            linklistsizeint *ll_other;
            if (level == 0)
                ll_other = get_linklist0(selectedNeighbors[idx]);
            else
                ll_other = get_linklist(selectedNeighbors[idx], level);
            size_t sz_link_list_other = getListCount(ll_other);
            if (sz_link_list_other > Mcurmax)
                throw std::runtime_error("Bad value of sz_link_list_other");
            if (selectedNeighbors[idx] == cur_c)
                throw std::runtime_error("Trying to connect an element to itself");
            if (level > element_levels_[selectedNeighbors[idx]])
                throw std::runtime_error("Trying to make a link on a non-existent level");
            tableint *data = (tableint *) (ll_other + 1);
            bool is_cur_c_present = false;
            if (isUpdate) {
                for (size_t j = 0; j < sz_link_list_other; j++) {
                    if (data[j] == cur_c) {
                        is_cur_c_present = true;
                        break;
                    }
                }
            }
            // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
            if (!is_cur_c_present) {
                if (sz_link_list_other < Mcurmax) {// 没有超过最大邻居数量
                    data[sz_link_list_other] = cur_c;
                    setListCount(ll_other, sz_link_list_other + 1);
                } else {
                    // finding the "weakest" element to replace it with the new one
                    dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
                                                dist_func_param_);
                    // Heuristic:
                    std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
                    candidates.emplace(d_max, cur_c);
                    for (size_t j = 0; j < sz_link_list_other; j++) {
                        candidates.emplace(
                                fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
                                                dist_func_param_), data[j]);
                    }
                    getNeighborsByHeuristic2(candidates, Mcurmax);
                    // 将 cur 和已有节点 与这个 neighbor 的距离计算出来再次启发式筛选
                    int indx = 0;
                    while (candidates.size() > 0) {
                        data[indx] = candidates.top().second;
                        candidates.pop();
                        indx++;
                    }// 重新填入邻接表
                    setListCount(ll_other, indx);
                    // Nearest K:
                    /*int indx = -1;
                    for (int j = 0; j < sz_link_list_other; j++) {
                        dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
                        if (d > d_max) {
                            indx = j;
                            d_max = d;
                        }
                    }
                    if (indx >= 0) {
                        data[indx] = cur_c;
                    } */
                }
            }
        }
        return next_closest_entry_point;
    }

# 启发式选择邻居

void getNeighborsByHeuristic2(
        std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
        const size_t M) {
        if (top_candidates.size() < M) {
            return;
        }
        std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
        std::vector<std::pair<dist_t, tableint>> return_list;
        while (top_candidates.size() > 0) {
            queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
            top_candidates.pop();
        }// 新的优先队列保持距离从远到近的顺序
        while (queue_closest.size()) {
            if (return_list.size() >= M)
                break;
            std::pair<dist_t, tableint> curent_pair = queue_closest.top();
            dist_t dist_to_query = -curent_pair.first;
            queue_closest.pop();
            bool good = true;
            for (std::pair<dist_t, tableint> second_pair : return_list) {
                dist_t curdist =
                        fstdistfunc_(getDataByInternalId(second_pair.second),
                                        getDataByInternalId(curent_pair.second),
                                        dist_func_param_);// 距离计算
                if (curdist < dist_to_query) {
                    good = false;
                    break;
                }
            }
            if (good) {
                return_list.push_back(curent_pair);
            }
        }// 这里没有论文中补充至 M 个的步骤,也就是说会少于 M 个
        for (std::pair<dist_t, tableint> curent_pair : return_list) {
            top_candidates.emplace(-curent_pair.first, curent_pair.second);
        }
    }

# 参考

HNSW 算法原理与源码解读_hnsw 有多少层 - CSDN 博客
nmslib/hnswlib: Header-only C++/python library for fast approximate nearest neighbors (github.com)

此文章已被阅读次数:正在加载...更新于

请我喝[茶]~( ̄▽ ̄)~*

GuitarYui 微信支付

微信支付

GuitarYui 支付宝

支付宝