// knn_NNDescent.js

import { euclidean } from "../metrics/index.js";
import { Randomizer } from "../util/index.js";
import { Heap } from "../datastructure/index.js";
/**
 * Approximate k-nearest-neighbor graph built with the NN-Descent scheme:
 * every element starts with a random neighbor list, which is then refined
 * by repeatedly comparing each element's neighbors (and reverse neighbors)
 * with one another.
 * @class
 * @alias NNDescent
 */
export class NNDescent {
    /**
     * @constructor
     * @memberof module:knn
     * @alias NNDescent
     * @param {Array<*>=} elements - called V in paper. When given, the graph is built immediately via {@link add}.
     * @param {Function} [metric = euclidean] - called sigma in paper.
     * @param {Number} [K = 10] - number of neighbors sampled for each element's initial list.
     * @param {Number} [rho = 1] - sample rate; the per-step sample size is K * rho.
     * @param {Number} [delta = 1e-3] - precision parameter; iteration stops once an iteration makes at most delta * N * K updates.
     * @param {Number} [seed = 19870307] - seed for the random number generator.
     * @returns {NNDescent}
     * @see {@link http://www.cs.princeton.edu/cass/papers/www11.pdf}
     */
    constructor(elements, metric = euclidean, K = 10, rho = 1, delta = 1e-3, seed = 19870307) {
        this._metric = metric;
        this._randomizer = new Randomizer(seed);
        this._K = K;
        this._rho = rho;
        this._sample_size = K * rho;
        this._delta = delta;
        if (elements) {
            this.add(elements);
        }
    }

    /**
     * Samples Array A with sample size.
     * Returns A itself (not a copy) when the sample size exceeds A.length,
     * otherwise delegates to the randomizer's choice(A, sample_size).
     * @private
     * @param {Array<*>} A
     * @returns {Array<*>}
     */
    _sample(A) {
        const sample_size = this._sample_size;
        if (sample_size > A.length) {
            return A;
        }
        return this._randomizer.choice(A, sample_size);
    }

    /**
     * Updates the KNN heap and returns 1 if changed, or 0 if not.
     * When the heap accepts u, u is flagged as new and one entry is popped
     * again so the heap does not grow (presumably the furthest entry, since
     * the heaps are created with the "max" comparator).
     * @private
     * @param {KNNHeap} B
     * @param {*} u
     * @returns {Number} 1 if B accepted u, otherwise 0.
     */
    _update(B, u) {
        if (!B.push(u)) {
            return 0;
        }
        u.flag = true;
        B.pop();
        return 1;
    }

    /**
     * Collects for each element where it is neighbor from, i.e. R[i] holds
     * every element j whose list B[j] contains element i.
     * @private
     * @param {Array<KNNHeap>} B
     * @returns {Array<Array<*>>} reverse neighbor lists, indexed like B.
     */
    _reverse(B) {
        const N = this._N;
        const elements = this._elements;
        const R = Array.from({ length: N }, () => []);
        // One pass over every neighbor list: j lists `neighbor`, therefore
        // j is a reverse neighbor of `neighbor`.
        for (let j = 0; j < N; ++j) {
            for (const neighbor of B[j].data()) {
                R[neighbor.index].push(elements[j]);
            }
        }
        return R;
    }

    /**
     * Builds the neighbor lists for the given elements. Despite the name,
     * previously stored elements and lists are replaced, not extended.
     * @param {Array} elements
     * @returns {NNDescent} this instance, for chaining.
     */
    add(elements) {
        // Wrap every element so that its index and "new" flag travel with it.
        const wrapped = elements.map((element, index) => {
            return {
                "element": element,
                "index": index,
                "flag": true,
            };
        });
        this._elements = wrapped;
        const randomizer = this._randomizer;
        const metric = this._metric;
        const K = this._K;
        const delta = this._delta;
        const N = this._N = wrapped.length;
        const B = this._B = [];
        // B[v] <-- Sample(V,K)
        for (let i = 0; i < N; ++i) {
            const e = wrapped[i];
            const sample = randomizer.choice(wrapped, K);
            // "max" to pop the furthest elements away
            B.push(new KNNHeap(sample, (d) => metric(d.element, e.element), "max"));
        }

        // Iterate until an iteration makes few enough updates, or the
        // number of updates stops changing between iterations.
        const threshold = delta * N * K;
        let c = Infinity;
        let old_c = -Infinity;
        while (c > threshold && c !== old_c) {
            // parallel for v e V do
            const old_ = new Array(N);
            const new_ = new Array(N);
            for (let i = 0; i < N; ++i) {
                const e = wrapped[i];
                const Bi = B[i].data();
                const falseBs = Bi.filter(d => !d.flag);
                const trueBs = this._sample(Bi.filter(d => d.flag));
                // Sampled "new" neighbors count as "old" from now on.
                trueBs.forEach(d => d.flag = false);
                old_[i] = new KNNHeap(falseBs, (d) => metric(d.element, e.element), "max");
                new_[i] = new KNNHeap(trueBs, (d) => metric(d.element, e.element), "max");
            }
            const old_reverse = this._reverse(old_);
            const new_reverse = this._reverse(new_);
            old_c = c;
            c = 0;
            // parallel for v e V do
            for (let i = 0; i < N; ++i) {
                this._sample(old_reverse[i]).forEach(o => old_[i].push(o));
                this._sample(new_reverse[i]).forEach(n => new_[i].push(n));
                const new_i = new_[i].data();
                const old_i = old_[i].data();
                const n1 = new_i.length;
                const n2 = old_i.length;
                for (let j = 0; j < n1; ++j) {
                    const u1 = new_i[j];
                    const Bu1 = B[u1.index];
                    // new x new: introduce the two neighbors to each other.
                    for (let k = 0; k < n1; ++k) {
                        const u2 = new_i[k];
                        if (u1 === u2) continue;
                        const Bu2 = B[u2.index];
                        c += this._update(Bu2, u1);
                        c += this._update(Bu1, u2);
                    }
                    // new x old
                    for (let k = 0; k < n2; ++k) {
                        const u2 = old_i[k];
                        if (u1 === u2) continue;
                        const Bu2 = B[u2.index];
                        c += this._update(Bu2, u1);
                        c += this._update(Bu1, u2);
                    }
                }
            }
        }
        return this;
    }

    /**
     * @todo not implemented yet
     * Placeholder: x is ignored and the list of a randomly chosen element is
     * returned, truncated to k entries.
     * NOTE(review): the modulus (N - 1) never selects the last element and is
     * 0 when only one element is stored - confirm before relying on this.
     * @param {*} x
     * @param {*} k
     */
    search(x, k=5) {
        return this._B[this._randomizer.random_int % (this._N - 1)].toArray().slice(0, k);
    }

    /**
     * Returns up to k entries of the i-th element's neighbor list with the
     * smallest `value`, in ascending order.
     * @param {Number} i - index of the element, as passed to {@link add}.
     * @param {Number} [k = 5] - maximum number of entries to return.
     * @returns {Array<*>} heap entries as delivered by raw_data(); presumably
     * `value` is the distance computed by the heap's accessor - TODO confirm.
     */
    search_index(i, k=5) {
        const B = this._B[i];
        // Sort a copy: sorting the array handed out by raw_data() in place
        // could reorder the heap's own storage.
        const ascending = B.raw_data().slice().sort((a, b) => a.value - b.value);
        return ascending.slice(0, k);
    }
}

/**
 * Heap that admits every element at most once over its whole lifetime.
 * Elements are remembered in `this.set`; an element that was pushed before
 * is rejected, even after it has been popped again.
 * @private
 */
class KNNHeap extends Heap {
    /**
     * @param {Iterable<*>=} elements - initial elements, pushed one by one.
     * @param {Function} accessor - forwarded to the Heap constructor.
     * @param {*} comparator - forwarded to the Heap constructor.
     */
    constructor(elements, accessor, comparator) {
        // The base heap starts empty so that every initial element goes
        // through the de-duplicating push below.
        super(null, accessor, comparator);
        this.set = new Set();
        if (elements) {
            for (const element of elements) {
                this.push(element);
            }
        }
    }

    /**
     * Pushes the element unless it has been pushed before.
     * @param {*} element
     * @returns {Boolean} true if the element was added, false if it was rejected.
     */
    push(element) {
        const seen = this.set;
        if (seen.has(element)) {
            return false;
        }
        seen.add(element);
        super.push(element);
        return true;
    }

    /**
     * Removes the top entry of the heap and returns whatever the base
     * Heap's pop returns.
     * The popped element deliberately stays in `this.set`: once popped it
     * should not return into the heap. Used as max heap, the entry removed
     * is the furthest element in the knn list.
     * @returns {*} the result of the base class pop.
     */
    pop() {
        return super.pop();
    }
}