src/jdk.jfr/share/classes/jdk/jfr/internal/util/PerfectHashMap.java
author mgronlun
Fri, 17 May 2019 18:03:14 +0200
branchJEP-349-branch
changeset 57361 53dccc90a5be
permissions -rw-r--r--
Preview-addendum

package jdk.jfr.internal.util;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

public class PerfectHashMap<V> {
    private static final long COLLISION_SHIFT = 63;
    private static final long COLLISION_BIT = 1L << COLLISION_SHIFT;
    private static final long COLLISION_MASK = COLLISION_BIT - 1;
    private static final int MAX_REMAP_ATTEMPTS = 100000;
    private static final int  MAX_ATTEMPS_BEFORE_RESIZE = 100;

    static final long W = 64L;
    static class LinkedValue<V> {
        final V value;
        long next;

        LinkedValue(V value) {
            this.value = value;
            this.next = 0;
        }
    }

    private UniversalHashFamily hashFamily = new UniversalHashFamily();
    private PrimitiveHashMap<LinkedValue<V>> loadMap;
    private Object[] valueTable;
    private long[] routeTable;
    private long shift;
    private long shiftMask;
    private int tableLengthMask;
    private long primaryHashFunction = 0;
    private int collisions = 0;
    private int retries = 0;
    private int sizeFactor = 1;
    private boolean minimal;

    public V get(long key) {
        LinkedValue<V> v = loadMap.get(key);
        return v != null ? v.value : null;
    }

    public V put(long key, V value) {
        LinkedValue<V> existing = loadMap.put(key, new LinkedValue<V>(value));
        return existing != null ? existing.value : null;
    }

    public void forEach(BiConsumer<? super Long, ? super V> action) {
        //loadMap.forEach(PerfectHashMap<V>::callback);
    }

    public final void forEach(Consumer<? super V> action) {
        //loadMap.forEach(action);
    }

    public final long[] keys() {
        return loadMap.keys();
    }

    static class Log2 {
        private static final int MAX_SIZE_EXPONENT = 32;

        static long log2base10(long exponent) {
            return 1L << exponent;
        }

        static int log2(int value) {
            int i = 0;
            int lastMultiple = 0;
            while (i < MAX_SIZE_EXPONENT) {
                int multiple = (int)log2base10(i);
                if ((value & multiple) != 0) {
                    lastMultiple = i;
                }
                ++i;
            }
            return ((int)log2base10(lastMultiple) ^ value) != 0 ? lastMultiple + 1 : lastMultiple;
        }
            }

    static final int tableExponent(int cap) {
        return Log2.log2(cap);
    }

    PerfectHashMap() {
        this(false, 101);
    }

    PerfectHashMap(int size) {
        this(false, size);
    }

    PerfectHashMap(boolean minimal, int size) {
        this.minimal = minimal;
        this.loadMap = new PrimitiveHashMap<>(size);
        this.primaryHashFunction = hashFamily.getRandomHashFunction();
    }

    @SuppressWarnings("unchecked")
    public V getPerfect(long key) {
        int routeIdx = getIndex(key, primaryHashFunction);
        assert(routeIdx >= 0);
        assert(routeIdx < routeTable.length);
        long element = routeTable[routeIdx];
        int valueIdx = element < 0 ? getIndex(key, -element - 1) : (int)element;
        assert(valueIdx >= 0);
        assert(valueIdx < valueTable.length);
        return (V)valueTable[valueIdx];
    }

    private long getRandomHashFunction() {
        return hashFamily.getRandomHashFunction();
    }
    private int getIndex(long key, long hashFunction) {
       final int idx = UniversalHashFamily.getIndex(key, hashFunction, shift, shiftMask);
       assert(idx >= 0);
       assert(idx < routeTable.length);
       return idx;
    }
    private static boolean isColliding(long entry) {
        return entry < 0;
    }
    private boolean isNonColliding(long entry) {
        return entry > 0;
    }
    private static long setColliding(long entry) {
        return entry | COLLISION_BIT;
    }
    private static long read(long entry) {
        return entry & COLLISION_MASK;
    }

    private int nextValueTableSlot(int lastIdx) {
        assert(lastIdx < valueTable.length);
        int i = lastIdx;
        for (; i < valueTable.length; ++i) {
            if (valueTable[i] == null) {
                break;
            }
        }
        return i;
    }

    private int valueTableStore(V value, int lastIdx) {
        if (lastIdx > valueTable.length) {
            lastIdx = 0;
        }
        assert(lastIdx < valueTable.length);
        final int idx = nextValueTableSlot(lastIdx);
        assert(idx < valueTable.length);
        assert(valueTable[idx] == null);
        valueTable[idx] = value;
        return idx;
    }


    private void routeNonCollisions() {
        int lastIdx = 0;
        for (int i = 0; i < routeTable.length; ++i) {
            if (isNonColliding(routeTable[i])) {
                lastIdx = valueTableStore(loadMap.get(routeTable[i]).value, lastIdx);
                routeTable[i] = lastIdx++;
           }
        }
    }

    private void rollback(int idx, int length, long hashFunction) {
        assert(isColliding(routeTable[idx]));
        long key = read(routeTable[idx]);
        LinkedValue<V> v = loadMap.get(key); // boxing
        for (int i = 0; i < length; ++i) {
            final int valueIdx = getIndex(key, hashFunction);
            assert(valueIdx >= 0);
            assert(valueIdx < valueTable.length);
            assert(valueTable[valueIdx] != null);
            valueTable[valueIdx] = null;
            key = v.next;
            v = loadMap.get(v.next); // no boxing
        }
    }

    private boolean remap(int idx, long hashFunction) {
        assert(isColliding(routeTable[idx]));
        int completed = 0;
        long key = read(routeTable[idx]);
        LinkedValue<V> v = loadMap.get(key);
        while (key != 0) {
            final int valueIdx = getIndex(key, hashFunction);
            assert(valueIdx >= 0);
            assert(valueIdx < valueTable.length);
            if (valueTable[valueIdx] == null) {
                valueTable[valueIdx] = v.value;
                ++completed;
                key = v.next;
                v = loadMap.get(v.next);
                continue;
            }
            rollback(idx, completed, hashFunction);
            return false;
        }
        return true;
    }

    private boolean routeCollisions(int idx) {
        assert(isColliding(routeTable[idx]));
        boolean success = false;
        int attempts = 0;
        long randomHashFunction = 0;
        do {
            randomHashFunction = getRandomHashFunction();
            success = remap(idx, randomHashFunction);
            if (++attempts == MAX_REMAP_ATTEMPTS) {
                System.out.println("Failed number of attempts - restart: " + attempts);
                return false;
            }
        } while (!success);
        System.out.println("Number of remap attempts: " + attempts);
        routeTable[idx] = -1 - randomHashFunction;
        assert(-routeTable[idx] - 1 == randomHashFunction);
        return true;
    }


    private boolean routeCollisions() {
        for (int i = 0; i < routeTable.length; ++i) {
            if (isColliding(routeTable[i])) {
                if (!routeCollisions(i)) {
                    return false;
                }
            }
        }
        return true;
    }

    private static void clearLongTable(long[] table) {
        Arrays.fill(table, 0);
        for (int i = 0; i < table.length; ++i) {
            assert(table[i] == 0);
        }
    }

    private static <T extends Object> void clearReferenceTable(T[] table) {
        Arrays.fill(table, null);
        for (int i = 0; i < table.length; ++i) {
            assert(table[i] == null);
        }
    }

    private void unlinkChains() {
        for (long key : loadMap.keys()) {
            loadMap.get(key).next = 0;
        }
    }

    private void routeTableStore(long key, LinkedValue<V> value, int idx) {
        assert(idx >= 0);
        assert(idx < routeTable.length);
        long existing = read(routeTable[idx]);
        if (existing == 0) {
            routeTable[idx] = key;
            return;
        }
        ++collisions;
        routeTable[idx] = setColliding(existing);
        LinkedValue<V> existingValue = loadMap.get(existing);
        value.next = existingValue.next;
        existingValue.next = key;
    }

    private void mapKeys() {
        for (long key : loadMap.keys()) {
            routeTableStore(key, loadMap.get(key), getIndex(key, primaryHashFunction));
        }
    }

    private void validate() {
        for (long key : loadMap.keys()) {
            long element = routeTable[getIndex(key, primaryHashFunction)];
            int valueIdx = element < 0 ? getIndex(key, -element - 1) : (int)element;
            assert(valueIdx >= 0);
            assert(loadMap.get(key) == valueTable[valueIdx]);
        }
    }

    private void reset() {
        collisions = 0;
        clearLongTable(routeTable);
        clearReferenceTable(valueTable);
    }

    private int dimensionTableSize() {
        int size = loadMap.size() * sizeFactor;
        return (int)Log2.log2base10(Log2.log2(size));
    }

    @SuppressWarnings({"rawtypes","unchecked"})
    private void allocateTables() {
        int size = dimensionTableSize();
        this.tableLengthMask = size - 1;
        this.shift = W - tableExponent(size);
        this.shiftMask = Log2.log2base10(shift) - 1;
        routeTable = new long[size];
        valueTable = (V[])new Object[size];
        collisions = 0;
        retries = 0;
    }

    public void build() {
        start:
        while (true) {
            allocateTables();
            System.out.println("Table size " + routeTable.length);
            mapKeys();
            if (collisions > 0) {
                if (!routeCollisions()) {
                    unlinkChains();
                    if (++retries <= MAX_ATTEMPS_BEFORE_RESIZE) {
                      reset();
                    } else {
                      sizeFactor *= 2;
                    }
                    continue start;
                }
            }
            routeNonCollisions();
            return;
        }
    }

    public void rebuild() {
        sizeFactor = 1;
        build();
    }
    public int size() {
        return loadMap.size();
    }
}