Java Code Examples for com.jstarcraft.ai.environment.EnvironmentContext#getContext()

The following examples show how to use com.jstarcraft.ai.environment.EnvironmentContext#getContext(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: HMMModel.java — from jstarcraft-rns, Apache License 2.0 (6 votes)
/**
 * Runs the E-step: calls {@code calculateGammaRho} once per user, handing each user to
 * {@link EnvironmentContext#doAlgorithmByAny}, and blocks until every user has been processed.
 *
 * <p>The latch is counted down in a {@code finally} block so that a failing task cannot leave this
 * method waiting forever; the first failure raised by a task is rethrown on the calling thread.
 *
 * @throws ModelException if a task fails or the calling thread is interrupted while waiting
 */
@Override
protected void eStep() {
    EnvironmentContext context = EnvironmentContext.getContext();
    // Parallel computation: one task per user, tracked by the latch.
    CountDownLatch latch = new CountDownLatch(userSize);
    // First exception thrown inside any task; surfaced after all tasks have finished.
    java.util.concurrent.atomic.AtomicReference<Exception> failure =
            new java.util.concurrent.atomic.AtomicReference<>();
    for (int userIndex = 0; userIndex < userSize; userIndex++) {
        int user = userIndex;
        context.doAlgorithmByAny(userIndex, () -> {
            try {
                calculateGammaRho(user, dataMatrixes[user]);
            } catch (Exception exception) {
                failure.compareAndSet(null, exception);
            } finally {
                // Always count down, otherwise await() below would block forever on a failed task.
                latch.countDown();
            }
        });
    }
    try {
        latch.await();
    } catch (InterruptedException exception) {
        // Restore the interrupt flag so callers further up can still observe the interruption.
        Thread.currentThread().interrupt();
        throw new ModelException(exception);
    }
    Exception taskFailure = failure.get();
    if (taskFailure != null) {
        throw new ModelException(taskFailure);
    }
}
 
Example 2
Source File: AbstractModel.java — from jstarcraft-rns, Apache License 2.0 (5 votes)
/**
 * Trains the model by running {@code constructEnvironment} through
 * {@link EnvironmentContext#doAlgorithmByEvery}, then {@link #doPractice()}, then
 * {@code destructEnvironment} through {@link EnvironmentContext#doAlgorithmByEvery}.
 *
 * <p>The destruct step runs in a {@code finally} block so it is also executed when
 * {@code doPractice()} throws. The exception from {@code doPractice()} still propagates.
 *
 * <p>NOTE(review): if {@code destructEnvironment} itself throws, that exception replaces the
 * original one from {@code doPractice()}.
 */
@Override
public final void practice() {
    EnvironmentContext context = EnvironmentContext.getContext();
    context.doAlgorithmByEvery(this::constructEnvironment);
    try {
        doPractice();
    } finally {
        context.doAlgorithmByEvery(this::destructEnvironment);
    }
}
 
Example 3
Source File: MathCorrelation.java — from jstarcraft-ai, Apache License 2.0 (5 votes)
/**
 * Calculates correlation coefficients from a score matrix and reports them to {@code monitor}.
 *
 * <p>The vectors correlated are the rows of {@code scoreMatrix}, or its columns when
 * {@code transpose} is {@code true}. Vectors with no elements are skipped. For every non-empty
 * vector, the self-coefficient {@link #getIdentical()} is reported first. Then each pair with a
 * higher index is computed in a task handed to {@link EnvironmentContext#doAlgorithmByAny}. The
 * method waits for all of that vector's tasks before moving on to the next vector. Pairs whose
 * coefficient is NaN are not reported.
 *
 * <p>NOTE(review): {@code monitor} may be called from several worker threads at once —
 * presumably it must be thread-safe; confirm.
 *
 * @param scoreMatrix the score matrix whose rows or columns are correlated
 * @param transpose {@code true} to correlate columns, {@code false} to correlate rows
 * @param monitor receives every calculated coefficient
 * @throws RuntimeException if a task fails or the calling thread is interrupted while waiting
 */
default void calculateCoefficients(MathMatrix scoreMatrix, boolean transpose, CorrelationMonitor monitor) {
    EnvironmentContext context = EnvironmentContext.getContext();
    Semaphore semaphore = new Semaphore(0);
    // First exception thrown inside any task; surfaced on the calling thread.
    java.util.concurrent.atomic.AtomicReference<Exception> failure =
            new java.util.concurrent.atomic.AtomicReference<>();
    int count = transpose ? scoreMatrix.getColumnSize() : scoreMatrix.getRowSize();
    for (int leftIndex = 0; leftIndex < count; leftIndex++) {
        MathVector thisVector = transpose ? scoreMatrix.getColumnVector(leftIndex) : scoreMatrix.getRowVector(leftIndex);
        if (thisVector.getElementSize() == 0) {
            continue;
        }
        monitor.notifyCoefficientCalculated(leftIndex, leftIndex, getIdentical());
        // The vector itself (user/item) is excluded below: only rightIndex > leftIndex is computed.
        int permits = 0;
        for (int rightIndex = leftIndex + 1; rightIndex < count; rightIndex++) {
            MathVector thatVector = transpose ? scoreMatrix.getColumnVector(rightIndex) : scoreMatrix.getRowVector(rightIndex);
            if (thatVector.getElementSize() == 0) {
                continue;
            }
            int leftCursor = leftIndex;
            int rightCursor = rightIndex;
            // NOTE(review): leftIndex * rightIndex overflows int once count exceeds ~46341 —
            // confirm doAlgorithmByAny tolerates negative keys.
            context.doAlgorithmByAny(leftIndex * rightIndex, () -> {
                try {
                    float coefficient = getCoefficient(thisVector, thatVector);
                    if (!Float.isNaN(coefficient)) {
                        monitor.notifyCoefficientCalculated(leftCursor, rightCursor, coefficient);
                    }
                } catch (Exception exception) {
                    failure.compareAndSet(null, exception);
                } finally {
                    // Always release, otherwise acquire(permits) below would block forever.
                    semaphore.release();
                }
            });
            permits++;
        }
        try {
            semaphore.acquire(permits);
        } catch (InterruptedException exception) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException(exception);
        }
        Exception taskFailure = failure.get();
        if (taskFailure != null) {
            throw new RuntimeException(taskFailure);
        }
    }
}
 
Example 4
Source File: AssociationRuleModel.java — from jstarcraft-rns, Apache License 2.0 (4 votes)
/**
 * Builds the item-to-item association matrix with the simple rule X => Y, where each user vector
 * is regarded as a transaction.
 *
 * <p>For every item pair (left &lt; right), a task handed to
 * {@link EnvironmentContext#doAlgorithmByAny} counts the indexes shared by the two item columns
 * of {@code scoreMatrix}. It then stores {@code count / |left|} at (left, right) and
 * {@code count / |right|} at (right, left). Pairs where either column is empty are left
 * untouched. The method waits for all pairs of one left item before moving on.
 *
 * @throws ModelException if a task fails or the calling thread is interrupted while waiting
 */
@Override
protected void doPractice() {
    EnvironmentContext context = EnvironmentContext.getContext();
    Semaphore semaphore = new Semaphore(0);
    // First exception thrown inside any task; surfaced on the calling thread.
    java.util.concurrent.atomic.AtomicReference<Exception> failure =
            new java.util.concurrent.atomic.AtomicReference<>();
    for (int leftItemIndex = 0; leftItemIndex < itemSize; leftItemIndex++) {
        // All transactions of the left item.
        SparseVector leftVector = scoreMatrix.getColumnVector(leftItemIndex);
        for (int rightItemIndex = leftItemIndex + 1; rightItemIndex < itemSize; rightItemIndex++) {
            SparseVector rightVector = scoreMatrix.getColumnVector(rightItemIndex);
            int leftIndex = leftItemIndex;
            int rightIndex = rightItemIndex;
            // NOTE(review): leftItemIndex * rightItemIndex overflows int once itemSize exceeds
            // ~46341 — confirm doAlgorithmByAny tolerates negative keys.
            context.doAlgorithmByAny(leftItemIndex * rightItemIndex, () -> {
                try {
                    int leftSize = leftVector.getElementSize();
                    int rightSize = rightVector.getElementSize();
                    if (leftSize != 0 && rightSize != 0) {
                        // Confidence: share of each item's transactions that also contain the other item.
                        int count = countCommonIndexes(leftVector, leftSize, rightVector, rightSize);
                        associationMatrix.setValue(leftIndex, rightIndex, (count + 0F) / leftSize);
                        associationMatrix.setValue(rightIndex, leftIndex, (count + 0F) / rightSize);
                    }
                } catch (Exception exception) {
                    failure.compareAndSet(null, exception);
                } finally {
                    // Always release, otherwise the acquire below would block forever.
                    semaphore.release();
                }
            });
        }
        try {
            // One permit per right item dispatched above.
            semaphore.acquire(itemSize - leftItemIndex - 1);
        } catch (InterruptedException exception) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new ModelException(exception);
        }
        Exception taskFailure = failure.get();
        if (taskFailure != null) {
            throw new ModelException(taskFailure);
        }
    }
}

/**
 * Counts the indexes present in both vectors using a merge-style two-pointer scan.
 *
 * <p>Assumes each vector's iterator yields elements in ascending index order (the original code
 * treated them as sorted arrays) — TODO confirm. Both sizes must be non-zero.
 *
 * @param leftVector first vector
 * @param leftSize element count of {@code leftVector}
 * @param rightVector second vector
 * @param rightSize element count of {@code rightVector}
 * @return number of indexes shared by both vectors
 */
private static int countCommonIndexes(SparseVector leftVector, int leftSize, SparseVector rightVector, int rightSize) {
    Iterator<VectorScalar> leftIterator = leftVector.iterator();
    Iterator<VectorScalar> rightIterator = rightVector.iterator();
    VectorScalar leftTerm = leftIterator.next();
    VectorScalar rightTerm = rightIterator.next();
    int leftCursor = 0;
    int rightCursor = 0;
    int count = 0;
    while (leftCursor < leftSize && rightCursor < rightSize) {
        // Read both indexes before advancing, in case an iterator reuses its scalar object.
        int leftPosition = leftTerm.getIndex();
        int rightPosition = rightTerm.getIndex();
        // Advance the side holding the smaller index; advance both on a match.
        boolean advanceLeft = leftPosition <= rightPosition;
        boolean advanceRight = leftPosition >= rightPosition;
        if (advanceLeft && advanceRight) {
            count++;
        }
        if (advanceLeft) {
            if (leftIterator.hasNext()) {
                leftTerm = leftIterator.next();
            }
            leftCursor++;
        }
        if (advanceRight) {
            if (rightIterator.hasNext()) {
                rightTerm = rightIterator.next();
            }
            rightCursor++;
        }
    }
    return count;
}
 
Example 5
Source File: RandomLayer.java — from jstarcraft-ai, Apache License 2.0 (4 votes)
/**
 * Runs the forward pass of this layer.
 *
 * <p>Input and output layout:
 * <ul>
 * <li>Each input row stores a column count in element 0, followed by that many chunks of
 * {@code numberOfInputs} values.</li>
 * <li>Output rows use the same offset: element 0 is copied from input column 0, and chunks of
 * {@code numberOfOutputs} values start at index 1.</li>
 * </ul>
 *
 * <p>For each chunk, the pre-activation values {@code input * weight (+ bias)} are written to the
 * matching {@code numberOfOutputs}-wide chunk of the middle data. {@code function} then maps that
 * middle chunk into the matching output chunk.
 *
 * <p>Rows are processed as tasks handed to {@link EnvironmentContext#doStructureByAny}; the method
 * waits for all of them, then clears the middle and output error matrices.
 *
 * @throws RuntimeException if a row task fails or the calling thread is interrupted while waiting
 */
@Override
public void doForward() {
    MathMatrix weightParameters = parameters.get(WEIGHT_KEY);
    MathMatrix biasParameters = parameters.get(BIAS_KEY);

    MathMatrix inputData = inputKeyValue.getKey();
    MathMatrix middleData = middleKeyValue.getKey();
    middleData.setValues(0F);
    MathMatrix outputData = outputKeyValue.getKey();
    outputData.setValues(0F);
    // Column 0 holds each row's column count; carry it over to the output unchanged.
    outputData.getColumnVector(0).copyVector(inputData.getColumnVector(0));
    int numberOfRows = inputData.getRowSize();
    EnvironmentContext context = EnvironmentContext.getContext();
    CountDownLatch latch = new CountDownLatch(numberOfRows);
    // First exception thrown inside any row task; surfaced on the calling thread.
    java.util.concurrent.atomic.AtomicReference<Exception> failure =
            new java.util.concurrent.atomic.AtomicReference<>();
    for (int rowIndex = 0; rowIndex < numberOfRows; rowIndex++) {
        MathVector inputMajorData = inputData.getRowVector(rowIndex);
        MathVector middleMajorData = middleData.getRowVector(rowIndex);
        MathVector outputMajorData = outputData.getRowVector(rowIndex);
        int numberOfColumns = (int) inputMajorData.getValue(0);
        context.doStructureByAny(rowIndex, () -> {
            try {
                for (int columnIndex = 0; columnIndex < numberOfColumns; columnIndex++) {
                    // Input and output chunks start after the leading column count (offset 1);
                    // middle chunks have no header and start at offset 0.
                    MathVector inputMinorData = GlobalVector.detachOf(GlobalVector.class.cast(inputMajorData), columnIndex * numberOfInputs + 1, (columnIndex + 1) * numberOfInputs + 1);
                    MathVector middleMinorData = GlobalVector.detachOf(GlobalVector.class.cast(middleMajorData), columnIndex * numberOfOutputs, (columnIndex + 1) * numberOfOutputs);
                    MathVector outputMinorData = GlobalVector.detachOf(GlobalVector.class.cast(outputMajorData), columnIndex * numberOfOutputs + 1, (columnIndex + 1) * numberOfOutputs + 1);
                    middleMinorData.dotProduct(inputMinorData, weightParameters, false, MathCalculator.SERIAL);
                    if (biasParameters != null) {
                        middleMinorData.iterateElement(MathCalculator.SERIAL, (scalar) -> {
                            int index = scalar.getIndex();
                            float value = scalar.getValue();
                            scalar.setValue(value + biasParameters.getValue(0, index));
                        });
                    }
                    // Activate only this chunk's pre-activation values: middleMajorData spans every
                    // chunk of the row and does not line up with outputMinorData.
                    function.forward(middleMinorData, outputMinorData);
                }
            } catch (Exception exception) {
                failure.compareAndSet(null, exception);
            } finally {
                // Always count down, otherwise await() below would block forever on a failed row.
                latch.countDown();
            }
        });
    }
    try {
        latch.await();
    } catch (InterruptedException exception) {
        // Restore the interrupt flag so callers further up can still observe the interruption.
        Thread.currentThread().interrupt();
        throw new RuntimeException(exception);
    }
    Exception taskFailure = failure.get();
    if (taskFailure != null) {
        throw new RuntimeException(taskFailure);
    }

    MathMatrix middleError = middleKeyValue.getValue();
    middleError.setValues(0F);

    MathMatrix innerError = outputKeyValue.getValue();
    innerError.setValues(0F);
}