cern.jet.random.Normal Java Examples

The following examples show how to use `cern.jet.random.Normal`. You can vote examples up or down, and follow the links above each example to the original project or source file; related API usage is listed in the sidebar. Note that Examples #13–#16 (from toolbox) appear to use that project's own `Normal` distribution class (with methods such as `setMean` and `getSd`) rather than `cern.jet.random.Normal`.
Example #1
Source File: CrazyCharactersSample.java    From micrometer with Apache License 2.0 6 votes vote down vote up
public static void main(String[] args) {
    MeterRegistry meterRegistry = SampleConfig.myMonitoringSystem();

    // Deliberately hostile identifiers (quotes, punctuation, braces, spaces) used as the
    // meter name, tag key and tag value to exercise the registry's name handling.
    String badCounterName = "\"\';^*()!~`_./?a{counter}:with123 weirdChars";
    String badTagName = "\"\';^*()!~`_./?a{tag}:with123 weirdChars";
    String badValueName = "\"\';^*()!~`_./?a{value}:with123 weirdChars";
    Counter weirdCounter = meterRegistry.counter(badCounterName, badTagName, badValueName);

    // Fixed seed so each run produces the same sequence of draws.
    Normal standardNormal = new Normal(0, 1, new MersenneTwister64(0));

    // Every 10 ms, increment with probability P(Z > -0.1) for Z ~ N(0, 1) (about 54%).
    Flux.interval(Duration.ofMillis(10))
            .doOnEach(signal -> {
                double draw = standardNormal.nextDouble();
                if (draw > -0.1) {
                    weirdCounter.increment();
                }
            })
            .blockLast();
}
 
Example #2
Source File: CounterSample.java    From micrometer with Apache License 2.0 6 votes vote down vote up
public static void main(String[] args) {
    MeterRegistry registry = SampleConfig.myMonitoringSystem();

    // The same meter name reported two ways, distinguished by the "method" tag:
    // a regular Counter ("actual") and a registry.more() counter that reads an
    // AtomicInteger ("function").
    Counter direct = registry.counter("counter", "method", "actual");
    AtomicInteger backing = new AtomicInteger(0);
    registry.more().counter("counter", Tags.of("method", "function"), backing);

    // Fixed seed so each run produces the same sequence of draws.
    Normal standardNormal = new Normal(0, 1, new MersenneTwister64(0));

    // Every 10 ms bump both counters together with probability P(Z > -0.1) (about 54%).
    Flux.interval(Duration.ofMillis(10))
            .doOnEach(signal -> {
                if (standardNormal.nextDouble() > -0.1) {
                    direct.increment();
                    backing.incrementAndGet();
                }
            })
            .blockLast();
}
 
Example #3
Source File: IndependentNormalDistributionSampler.java    From beast-mcmc with GNU Lesser General Public License v2.1 6 votes vote down vote up
public IndependentNormalDistributionSampler(Variable variable, NormalDistributionModel model, double weight, boolean updateAllIndependently) {
    this.variable = variable;
    this.model = model;
    this.updateAllIndependently = updateAllIndependently;
    setWeight(weight);

    if (!TRY_COLT) {
        // Per the original author, BEAST's own normal distribution has no way to draw a
        // sample with a specified mean and standard deviation, so Colt is required.
        throw new RuntimeException("Normal distribution in BEAST still needs a random sampler.");
    }

    // Colt Mersenne Twister seeded with a draw from MathUtils.nextInt().
    randomEngine = new MersenneTwister(MathUtils.nextInt());
    // Standard normal N(0, 1); per the original note its internal state is bypassed
    // when drawing, so only the engine matters.
    coltNormal = new Normal(0.0, 1.0, randomEngine);
}
 
Example #4
Source File: GibbsIndependentNormalDistributionOperator.java    From beast-mcmc with GNU Lesser General Public License v2.1 6 votes vote down vote up
public GibbsIndependentNormalDistributionOperator(Variable variable, NormalDistributionModel model, double weight, boolean updateAllIndependently) {
    this.variable = variable;
    this.model = model;
    this.updateAllIndependently = updateAllIndependently;
    setWeight(weight);

    if (!TRY_COLT) {
        // Per the original author, BEAST's own normal distribution has no way to draw a
        // sample with a specified mean and standard deviation, so Colt is required.
        throw new RuntimeException("Normal distribution in BEAST still needs a random sampler.");
    }

    // Colt Mersenne Twister seeded with a draw from MathUtils.nextInt().
    randomEngine = new MersenneTwister(MathUtils.nextInt());
    // Standard normal N(0, 1): Colt's Normal takes (mean, standard deviation). Per the
    // original note its internal state is bypassed when drawing.
    coltNormal = new Normal(0.0, 1.0, randomEngine);
}
 
Example #5
Source File: GibbsIndependentJointNormalGammaOperator.java    From beast-mcmc with GNU Lesser General Public License v2.1 6 votes vote down vote up
public GibbsIndependentJointNormalGammaOperator(Variable mean, Variable precision, NormalDistributionModel model, GammaDistribution gamma, double weight, boolean updateAllIndependently) {
    this.mean = mean;
    this.precision = precision;
    this.model = model;
    this.gamma = gamma;
    this.updateAllIndependently = updateAllIndependently;
    setWeight(weight);

    if (!TRY_COLT) {
        // Per the original author, BEAST's own normal distribution has no way to draw a
        // sample with a specified mean and standard deviation, so Colt is required.
        throw new RuntimeException("Normal distribution in BEAST still needs a random sampler.");
    }

    // Colt Mersenne Twister seeded with a draw from MathUtils.nextInt().
    randomEngine = new MersenneTwister(MathUtils.nextInt());
    // Standard normal N(0, 1): Colt's Normal takes (mean, standard deviation). Per the
    // original note its internal state is bypassed when drawing.
    coltNormal = new Normal(0.0, 1.0, randomEngine);
    // NOTE: a Colt Gamma sampler built from gamma.getShape() and 1.0 / gamma.getScale()
    // was sketched at this point but is not enabled.
}
 
Example #6
Source File: Frugal2UTest.java    From streaminer with Apache License 2.0 6 votes vote down vote up
@Test
public void testOffer() throws QuantilesException {
    System.out.println("offer");

    double[] quantiles = {0.05, 0.25, 0.5, 0.75, 0.95};
    Frugal2U instance = new Frugal2U(quantiles, 0);
    ExactQuantilesAll<Integer> exact = new ExactQuantilesAll<Integer>();

    // 1000 draws from N(100, 50) with a fixed seed, clamped at zero and truncated to
    // int, fed identically to the streaming estimator and the exact reference.
    Normal dist = new Normal(100, 50, new MersenneTwister64(0));
    final int numSamples = 1000;
    for (int remaining = numSamples; remaining > 0; remaining--) {
        int sample = (int) Math.max(0, dist.nextDouble());
        instance.offer(sample);
        exact.offer(sample);
    }

    // Side-by-side table of estimated vs. exact quantiles (printed only, not asserted).
    System.out.println("Q\tEst\tExact");
    for (double q : quantiles) {
        String row = q + "\t" + instance.getQuantile(q) + "\t" + exact.getQuantile(q);
        System.out.println(row);
    }
}
 
Example #7
Source File: TimerSample.java    From micrometer with Apache License 2.0 5 votes vote down vote up
public static void main(String[] args) {
    MeterRegistry registry = SampleConfig.myMonitoringSystem();

    // Histogram-backed timer with client-side percentiles, SLO boundaries and a short
    // (10 s expiry, 3 buffers) decay window for the distribution statistics.
    Timer timer = Timer.builder("timer")
            .publishPercentileHistogram()
            .publishPercentiles(0.5, 0.95, 0.99)
            .serviceLevelObjectives(Duration.ofMillis(275), Duration.ofMillis(300), Duration.ofMillis(500))
            .distributionStatisticExpiry(Duration.ofSeconds(10))
            .distributionStatisticBufferLength(3)
            .register(registry);

    // A FunctionTimer reporting the same recordings through two hand-maintained totals.
    AtomicLong recordedCount = new AtomicLong();
    AtomicLong recordedMillis = new AtomicLong();
    FunctionTimer.builder("ftimer", recordedCount, t -> recordedCount.get(), t -> recordedMillis.get(), TimeUnit.MILLISECONDS)
            .register(registry);

    RandomEngine engine = new MersenneTwister64(0);
    Normal arrivalNoise = new Normal(0, 1, engine);
    Normal latencyMillis = new Normal(250, 50, engine);

    // Once per second choose a new latency that every request in that second shares.
    AtomicInteger currentLatency = new AtomicInteger(latencyMillis.nextInt());
    Flux.interval(Duration.ofSeconds(1))
            .doOnEach(signal -> currentLatency.set(latencyMillis.nextInt()))
            .subscribe();

    // Every 10 ms a request "arrives" with probability P(Z > -0.4) (about 66%) and is
    // recorded with the current latency (normally distributed around 250 ms).
    Flux.interval(Duration.ofMillis(10))
            .doOnEach(signal -> {
                if (arrivalNoise.nextDouble() > -0.4) {
                    int latency = currentLatency.get();
                    timer.record(latency, TimeUnit.MILLISECONDS);
                    recordedMillis.addAndGet(latency);
                    recordedCount.incrementAndGet();
                }
            })
            .blockLast();
}
 
Example #8
Source File: GaugeSample.java    From micrometer with Apache License 2.0 5 votes vote down vote up
public static void main(String[] args) {
    MeterRegistry registry = SampleConfig.myMonitoringSystem();
    AtomicLong value = new AtomicLong();

    // Two gauges under one name, distinguished by the "k" tag: one reports the raw
    // value, the other reports it minus one.
    registry.gauge("gauge", Tags.of("k", "v"), value);
    registry.gauge("gauge", Tags.of("k", "v1"), value, v -> v.get() - 1);

    // Fixed-seed N(0, 10) noise source.
    Normal noise = new Normal(0, 10, new MersenneTwister64(0));

    // Every 5 s publish the magnitude of a fresh integer draw.
    Flux.interval(Duration.ofSeconds(5))
            .doOnEach(signal -> value.set(Math.abs(noise.nextInt())))
            .blockLast();
}
 
Example #9
Source File: LongTaskTimerSample.java    From micrometer with Apache License 2.0 5 votes vote down vote up
public static void main(String[] args) {
    MeterRegistry registry = SampleConfig.myMonitoringSystem();
    LongTaskTimer timer = registry.more().longTaskTimer("longTaskTimer");

    RandomEngine r = new MersenneTwister64(0);
    Normal incomingRequests = new Normal(0, 1, r);
    // Task lengths in ticks of the 1 s interval below; negative draws are rejected.
    Normal duration = new Normal(30, 50, r);

    // Each running task maps to a latch counting down its remaining ticks.
    final Map<LongTaskTimer.Sample, CountDownLatch> tasks = new ConcurrentHashMap<>();

    // The potential for an "incoming request" once per second. A new task is started
    // only while no other task is active (tasks.isEmpty()).
    Flux.interval(Duration.ofSeconds(1))
            .doOnEach(d -> {
                if (incomingRequests.nextDouble() + 0.4 > 0 && tasks.isEmpty()) {
                    // Redraw until the task length is non-negative.
                    int taskDur;
                    do {
                        taskDur = duration.nextInt();
                    } while (taskDur < 0);
                    synchronized (tasks) {
                        tasks.put(timer.start(), new CountDownLatch(taskDur));
                    }
                }

                // Advance every running task by one tick and stop those that finished.
                // Removing during iteration is safe: ConcurrentHashMap iterators are
                // weakly consistent and never throw ConcurrentModificationException.
                synchronized (tasks) {
                    for (Map.Entry<LongTaskTimer.Sample, CountDownLatch> e : tasks.entrySet()) {
                        e.getValue().countDown();
                        if (e.getValue().getCount() == 0) {
                            e.getKey().stop();
                            tasks.remove(e.getKey());
                        }
                    }
                }
            })
            .blockLast();
}
 
Example #10
Source File: TimerMaximumThroughputSample.java    From micrometer with Apache License 2.0 5 votes vote down vote up
public static void main(String[] args) {
    MeterRegistry registry = SampleConfig.myMonitoringSystem();

    // Histogram and SLO boundaries only; client-side percentiles
    // (publishPercentiles(0.5, 0.95, 0.99)) are intentionally left disabled here.
    Timer timer = Timer.builder("timer")
            .publishPercentileHistogram()
            .serviceLevelObjectives(Duration.ofMillis(275), Duration.ofMillis(300), Duration.ofMillis(500))
            .distributionStatisticExpiry(Duration.ofSeconds(10))
            .distributionStatisticBufferLength(3)
            .register(registry);

    // Fixed-seed latency source, normally distributed around 250 ms.
    Normal latencyMillis = new Normal(250, 50, new MersenneTwister64(0));

    // Once per second choose the latency that every recording in that second uses.
    AtomicInteger currentLatency = new AtomicInteger(latencyMillis.nextInt());
    Flux.interval(Duration.ofSeconds(1))
            .doOnEach(signal -> currentLatency.set(latencyMillis.nextInt()))
            .subscribe();

    // Unbounded stream of placeholder elements (0..999, repeating) fanned out over
    // 4 rails on the parallel scheduler; each signal records the current latency.
    Flux.fromStream(Stream.iterate(0, i -> (i + 1) % 1000))
            .parallel(4)
            .runOn(Schedulers.parallel())
            .doOnEach(signal -> timer.record(currentLatency.get(), TimeUnit.MILLISECONDS))
            .subscribe();

    // Block the main thread indefinitely; all work happens on scheduler threads.
    Flux.never().blockLast();
}
 
Example #11
Source File: FunctionTimerSample.java    From micrometer with Apache License 2.0 5 votes vote down vote up
public static void main(String[] args) {
    MeterRegistry registry = SampleConfig.myMonitoringSystem();

    Timer timer = Timer.builder("timer")
        .publishPercentiles(0.5, 0.95)
        .register(registry);

    // A FunctionTimer whose count and total time are read from two hand-maintained
    // totals; the placeholder object is only needed to satisfy the builder signature.
    Object placeholder = new Object();
    AtomicLong totalTimeNanos = new AtomicLong(0);
    AtomicLong totalCount = new AtomicLong(0);

    FunctionTimer.builder("ftimer", placeholder, p -> totalCount.get(), p -> totalTimeNanos.get(), TimeUnit.NANOSECONDS)
        .register(registry);

    RandomEngine r = new MersenneTwister64(0);
    Normal incomingRequests = new Normal(0, 1, r);
    Normal duration = new Normal(250, 50, r);

    // Once per second choose the latency that every request in that second uses.
    AtomicInteger latencyForThisSecond = new AtomicInteger(duration.nextInt());
    Flux.interval(Duration.ofSeconds(1))
        .doOnEach(d -> latencyForThisSecond.set(duration.nextInt()))
        .subscribe();

    // the potential for an "incoming request" every 10 ms
    Flux.interval(Duration.ofMillis(10))
        .doOnEach(d -> {
            if (incomingRequests.nextDouble() + 0.4 > 0) {
                // pretend the request took some amount of time, such that the time is
                // distributed normally with a mean of 250ms.
                // Read the shared latency once: the 1 s updater runs on another
                // subscription, so two reads could disagree and the Timer and the
                // FunctionTimer would then record different values for one request.
                int latency = latencyForThisSecond.get();
                timer.record(latency, TimeUnit.MILLISECONDS);
                totalCount.incrementAndGet();
                totalTimeNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(latency));
            }
        })
        .blockLast();
}
 
Example #12
Source File: SimulatedEndpointInstrumentation.java    From micrometer with Apache License 2.0 4 votes vote down vote up
public static void main(String[] args) {
    MeterRegistry registry = SampleConfig.myMonitoringSystem();

    // One timer per (uri, response) pair under the shared "http.server.requests" name.
    // e1* timers are tagged uri=/api/bar, e2* timers are tagged uri=/api/foo.
    Timer e1Success = Timer.builder("http.server.requests")
        .tags("uri", "/api/bar")
        .tags("response", "200")
        .publishPercentiles(0.5, 0.95)
        .register(registry);

    Timer e2Success = Timer.builder("http.server.requests")
        .tags("uri", "/api/foo")
        .tags("response", "200")
        .publishPercentiles(0.5, 0.95)
        .register(registry);

    Timer e1Fail = Timer.builder("http.server.requests")
        .tags("uri", "/api/bar")
        .tags("response", "500")
        .publishPercentiles(0.5, 0.95)
        .register(registry);

    Timer e2Fail = Timer.builder("http.server.requests")
        .tags("uri", "/api/foo")
        .tags("response", "500")
        .publishPercentiles(0.5, 0.95)
        .register(registry);

    RandomEngine r = new MersenneTwister64(0);
    Normal incomingRequests = new Normal(0, 1, r);
    Normal successOrFail = new Normal(0, 1, r);

    Normal duration = new Normal(250, 50, r);
    Normal duration2 = new Normal(250, 50, r);

    // Per-second latency for /api/bar (e1 timers).
    AtomicInteger latencyForThisSecond = new AtomicInteger(duration.nextInt());
    Flux.interval(Duration.ofSeconds(1))
        .doOnEach(d -> latencyForThisSecond.set(duration.nextInt()))
        .subscribe();

    // Per-second latency for /api/foo (e2 timers).
    AtomicInteger latencyForThisSecond2 = new AtomicInteger(duration2.nextInt());
    Flux.interval(Duration.ofSeconds(1))
        .doOnEach(d -> latencyForThisSecond2.set(duration2.nextInt()))
        .subscribe();

    // NOTE(review): incomingRequests, successOrFail and the shared engine r are drawn
    // from by several interval subscriptions, which may run on different scheduler
    // threads. Confirm whether the Colt classes tolerate concurrent use before relying
    // on the exact random sequence.

    // /api/bar: the potential for an "incoming request" every 10 ms
    Flux.interval(Duration.ofMillis(10))
        .doOnEach(d -> {
            // are we going to receive a request for /api/bar?
            if (incomingRequests.nextDouble() + 0.4 > 0) {
                // roughly P(Z > -0.8), about 79%, of requests succeed
                if (successOrFail.nextDouble() + 0.8 > 0) {
                    // pretend the request took some amount of time, such that the time is
                    // distributed normally with a mean of 250ms
                    e1Success.record(latencyForThisSecond.get(), TimeUnit.MILLISECONDS);
                }
                else {
                    e1Fail.record(latencyForThisSecond.get(), TimeUnit.MILLISECONDS);
                }
            }
        })
        .subscribe();

    // /api/foo: the potential for an "incoming request" every 1 ms
    Flux.interval(Duration.ofMillis(1))
        .doOnEach(d -> {
            // are we going to receive a request for /api/foo?
            if (incomingRequests.nextDouble() + 0.4 > 0) {
                // roughly P(Z > -0.8), about 79%, of requests succeed
                if (successOrFail.nextDouble() + 0.8 > 0) {
                    // pretend the request took some amount of time, such that the time is
                    // distributed normally with a mean of 250ms
                    e2Success.record(latencyForThisSecond2.get(), TimeUnit.MILLISECONDS);
                }
                else {
                    e2Fail.record(latencyForThisSecond2.get(), TimeUnit.MILLISECONDS);
                }
            }
        })
        .blockLast();
}
 
Example #13
Source File: VMPNormalTest.java    From toolbox with Apache License 2.0 4 votes vote down vote up
/**
 * Checks VMP on a two-node linear Gaussian network A -> B with no evidence.
 *
 * <p>A has mean 1 and variance 0.25; B given A has intercept 1, coefficient 1 on A and
 * variance 0.25. After inference, the posterior mean and standard deviation reported for
 * A and B are compared (tolerance 0.01) with values obtained by iterating closed-form
 * update equations until the sum of the four parameters changes by less than 0.001.
 *
 * @throws IOException declared by the test signature
 * @throws ClassNotFoundException declared by the test signature
 */
public static void test2() throws IOException, ClassNotFoundException{

        // Build the network A -> B.
        Variables variables = new Variables();
        Variable varA = variables.newGaussianVariable("A");
        Variable varB = variables.newGaussianVariable("B");

        DAG dag = new DAG(variables);

        dag.getParentSet(varB).addParent(varA);
        BayesianNetwork bn = new BayesianNetwork(dag);

        Normal distA = bn.getConditionalDistribution(varA);
        ConditionalLinearGaussian distB = bn.getConditionalDistribution(varB);

        // Set the model parameters.
        distA.setMean(1);
        distA.setVariance(0.25);
        distB.setIntercept(1);
        //distB.setCoeffParents(new double[]{1});
        distB.setCoeffForParent(varA, 1);
        distB.setVariance(0.25);

        if (Main.VERBOSE) System.out.println(bn.toString());

        // Prior (model) parameters used by the closed-form reference computation.
        double meanPA =  distA.getMean();
        double sdPA =  distA.getSd();

        double b0PB =  distB.getIntercept();
        //double b1PB = distB.getCoeffParents()[0];
        double b1PB = distB.getCoeffForParent(varA);
        double sdPB =  distB.getSd();

        VMP vmp = new VMP();
        vmp.setTestELBO(true);
        vmp.setMaxIter(100);
        vmp.setThreshold(0.0001);
        vmp.setModel(bn);

        // Starting values for the reference iteration, taken from the initial Q
        // distributions. Moment parameters are read as (E[X], E[X^2]), so
        // sd = sqrt(E[X^2] - E[X]^2).
        EF_Normal qADist = ((EF_Normal) vmp.getNodes().get(0).getQDist());
        EF_Normal qBDist = ((EF_Normal) vmp.getNodes().get(1).getQDist());

        double meanQA= qADist.getMomentParameters().get(0);
        double sdQA= Math.sqrt(qADist.getMomentParameters().get(1) - qADist.getMomentParameters().get(0) * qADist.getMomentParameters().get(0));

        double meanQB= qBDist.getMomentParameters().get(0);
        double sdQB= Math.sqrt(qBDist.getMomentParameters().get(1) - qBDist.getMomentParameters().get(0)*qBDist.getMomentParameters().get(0));

        vmp.runInference();

        Normal postA = vmp.getPosterior(varA);
        if (Main.VERBOSE) System.out.println("P(A) = " + postA.toString());
        Normal postB = ((Normal)vmp.getPosterior(varB));
        if (Main.VERBOSE) System.out.println("P(B) = " + postB.toString());

        boolean convergence = false;
        double oldvalue = 0;

        // Fixed-point iteration of the closed-form updates. Each update uses the most
        // recent value of the other variable's mean, so statement order matters.
        while(!convergence){

            sdQA = Math.sqrt(Math.pow(b1PB*b1PB/(sdPB*sdPB) + 1.0/(sdPA*sdPA),-1));
            meanQA = sdQA*sdQA*(b1PB*meanQB/(sdPB*sdPB) - b0PB*b1PB/(sdPB*sdPB) + meanPA/(sdPA*sdPA));

            // B has no children in this network, so its sd update is just sdPB.
            sdQB = sdPB;
            meanQB = sdQB*sdQB*(b0PB/(sdPB*sdPB) + b1PB*meanQA/(sdPB*sdPB));

            // Stop once the summed parameters move by less than 0.001 between passes.
            if (Math.abs(sdQA + meanQA + sdQB + meanQB - oldvalue) < 0.001) {
                convergence = true;
            }

            oldvalue = sdQA + meanQA + sdQB + meanQB ;
        }

        if (Main.VERBOSE) System.out.println("Mean and Sd of A: " + meanQA +", " + sdQA );
        if (Main.VERBOSE) System.out.println("Mean and Sd of B: " + meanQB +", " + sdQB );

        // VMP posteriors must match the reference values within 0.01.
        Assert.assertEquals(postA.getMean(),meanQA,0.01);
        Assert.assertEquals(postA.getSd(),sdQA,0.01);
        Assert.assertEquals(postB.getMean(),meanQB,0.01);
        Assert.assertEquals(postB.getSd(),sdQB,0.01);
    }
 
Example #14
Source File: VMPNormalTest.java    From toolbox with Apache License 2.0 4 votes vote down vote up
/**
 * Checks VMP on a collider A -> C <- B with evidence C = 0.7.
 *
 * <p>A has mean 1 and variance 0.25, B has mean 1.2 and variance 0.64, and C given A, B
 * has intercept 1, coefficient 1 on each parent and variance 0.25. The posterior mean and
 * standard deviation of A and B are compared (tolerance 0.01) with values obtained by
 * iterating closed-form update equations until the sum of the four parameters changes by
 * less than 0.001.
 *
 * @throws IOException declared by the test signature
 * @throws ClassNotFoundException declared by the test signature
 */
public static void test4() throws IOException, ClassNotFoundException{

        Variables variables = new Variables();
        Variable varA = variables.newGaussianVariable("A");
        Variable varB = variables.newGaussianVariable("B");
        Variable varC = variables.newGaussianVariable("C");

        DAG dag = new DAG(variables);

        dag.getParentSet(varC).addParent(varA);
        dag.getParentSet(varC).addParent(varB);

        BayesianNetwork bn = new BayesianNetwork(dag);

        Normal distA = bn.getConditionalDistribution(varA);
        Normal distB = bn.getConditionalDistribution(varB);
        ConditionalLinearGaussian distC = bn.getConditionalDistribution(varC);

        distA.setMean(1);
        distA.setVariance(0.25);

        distB.setMean(1.2);
        distB.setVariance(0.64);

        distC.setIntercept(1);
        distC.setCoeffForParent(varA, 1);
        distC.setCoeffForParent(varB, 1);
        distC.setVariance(0.25);

        if (Main.VERBOSE) System.out.println(bn.toString());

        // Model parameters used by the closed-form reference computation.
        double meanPA =  distA.getMean();
        double sdPA =  distA.getSd();

        double meanPB =  distB.getMean();
        double sdPB =  distB.getSd();

        double b0PC =  distC.getIntercept();
        double b1PC = distC.getCoeffForParent(varA);
        double b2PC = distC.getCoeffForParent(varB);
        double sdPC =  distC.getSd();

        VMP vmp = new VMP();
        vmp.setTestELBO(true);
        vmp.setMaxIter(100);
        vmp.setThreshold(0.0001);
        vmp.setModel(bn);

        // Starting values for the reference iteration, taken from the initial Q
        // distributions; moment parameters are read as (E[X], E[X^2]). C is observed,
        // so its Q distribution is not needed.
        EF_Normal qADist = ((EF_Normal) vmp.getNodes().get(0).getQDist());
        EF_Normal qBDist = ((EF_Normal) vmp.getNodes().get(1).getQDist());

        double meanQA= qADist.getMomentParameters().get(0);
        double sdQA= Math.sqrt(qADist.getMomentParameters().get(1) - qADist.getMomentParameters().get(0) * qADist.getMomentParameters().get(0));

        double meanQB= qBDist.getMomentParameters().get(0);
        double sdQB= Math.sqrt(qBDist.getMomentParameters().get(1) - qBDist.getMomentParameters().get(0) * qBDist.getMomentParameters().get(0));

        // Observed value of C, used both as evidence and in the reference formulas.
        double meanQC= 0.7;

        HashMapAssignment assignment = new HashMapAssignment(1);
        assignment.setValue(varC, 0.7);

        vmp.setEvidence(assignment);

        vmp.runInference();

        Normal postA = vmp.getPosterior(varA);
        if (Main.VERBOSE) System.out.println("P(A) = " + postA.toString());
        Normal postB = vmp.getPosterior(varB);
        if (Main.VERBOSE) System.out.println("P(B) = " + postB.toString());

        boolean convergence = false;
        double oldvalue = 0;

        // Fixed-point iteration: A's update uses the latest mean of B and vice versa.
        while(!convergence){

            sdQA = Math.sqrt(Math.pow(b1PC*b1PC/(sdPC*sdPC) + 1.0/(sdPA*sdPA),-1));
            meanQA = sdQA*sdQA*(b1PC*meanQC/(sdPC*sdPC) - b0PC*b1PC/(sdPC*sdPC) - b1PC*b2PC*meanQB/(sdPC*sdPC) + meanPA/(sdPA*sdPA));

            sdQB = Math.sqrt(Math.pow(b2PC*b2PC/(sdPC*sdPC) + 1.0/(sdPB*sdPB),-1));
            meanQB = sdQB*sdQB*(b2PC*meanQC/(sdPC*sdPC) - b0PC*b2PC/(sdPC*sdPC) - b1PC*b2PC*meanQA/(sdPC*sdPC) + meanPB/(sdPB*sdPB));

            if (Math.abs(sdQA + meanQA + sdQB + meanQB - oldvalue) < 0.001) {
                convergence = true;
            }
            oldvalue = sdQA + meanQA + sdQB + meanQB;
        }

        if (Main.VERBOSE) System.out.println("Mean and Sd of A: " + meanQA +", " + sdQA );
        if (Main.VERBOSE) System.out.println("Mean and Sd of B: " + meanQB +", " + sdQB );

        Assert.assertEquals(postA.getMean(),meanQA,0.01);
        Assert.assertEquals(postA.getSd(),sdQA,0.01);
        Assert.assertEquals(postB.getMean(),meanQB,0.01);
        Assert.assertEquals(postB.getSd(),sdQB,0.01);
    }
 
Example #15
Source File: VMPNormalTest.java    From toolbox with Apache License 2.0 4 votes vote down vote up
/**
 * Checks VMP on a fork A <- C -> B with evidence A = 0.7 and B = 0.2.
 *
 * <p>C has mean 1 and variance 0.25; A given C has intercept 1, coefficient 1 and
 * variance 0.25; B given C has intercept 1.5, coefficient 1 and variance 0.64. The
 * posterior mean and standard deviation of C are compared (tolerance 0.01) with the
 * closed-form values. The reference updates below do not depend on earlier passes, so
 * the convergence loop settles after its second pass.
 *
 * @throws IOException declared by the test signature
 * @throws ClassNotFoundException declared by the test signature
 */
public static void test6() throws IOException, ClassNotFoundException{

        Variables variables = new Variables();
        Variable varA = variables.newGaussianVariable("A");
        Variable varB = variables.newGaussianVariable("B");
        Variable varC = variables.newGaussianVariable("C");

        DAG dag = new DAG(variables);

        dag.getParentSet(varA).addParent(varC);
        dag.getParentSet(varB).addParent(varC);

        BayesianNetwork bn = new BayesianNetwork(dag);

        ConditionalLinearGaussian distA = bn.getConditionalDistribution(varA);
        ConditionalLinearGaussian distB = bn.getConditionalDistribution(varB);
        Normal distC = bn.getConditionalDistribution(varC);

        distA.setIntercept(1);
        distA.setCoeffForParent(varC, 1);
        distA.setVariance(0.25);

        distB.setIntercept(1.5);
        distB.setCoeffForParent(varC, 1);
        distB.setVariance(0.64);

        distC.setMean(1);
        distC.setVariance(0.25);

        if (Main.VERBOSE) System.out.println(bn.toString());

        // Model parameters used by the closed-form reference computation.
        double b0PA =  distA.getIntercept();
        double b1PA = distA.getCoeffForParent(varC);
        double sdPA =  distA.getSd();

        double b0PB =  distB.getIntercept();
        double b1PB = distB.getCoeffForParent(varC);
        double sdPB =  distB.getSd();

        double meanPC =  distC.getMean();
        double sdPC =  distC.getSd();

        VMP vmp = new VMP();
        vmp.setTestELBO(true);
        vmp.setMaxIter(100);
        vmp.setThreshold(0.0001);
        vmp.setModel(bn);

        // A and B are observed, so only C's initial Q distribution is needed.
        EF_Normal qCDist = ((EF_Normal) vmp.getNodes().get(2).getQDist());

        // Observed values of A and B, used both as evidence and in the reference formulas.
        double meanQA= 0.7;
        double meanQB= 0.2;

        HashMapAssignment assignment = new HashMapAssignment(1);
        assignment.setValue(varA, 0.7);
        assignment.setValue(varB, 0.2);

        // Moment parameters are read as (E[X], E[X^2]).
        double meanQC= qCDist.getMomentParameters().get(0);
        double sdQC= Math.sqrt(qCDist.getMomentParameters().get(1) - qCDist.getMomentParameters().get(0)*qCDist.getMomentParameters().get(0));

        vmp.setEvidence(assignment);
        vmp.runInference();

        Normal postC = ((Normal)vmp.getPosterior(varC));
        if (Main.VERBOSE) System.out.println("P(C) = " + postC.toString());

        boolean convergence = false;
        double oldvalue = 0;

        while(!convergence){
            sdQC = Math.sqrt(Math.pow(b1PA*b1PA/(sdPA*sdPA) + b1PB*b1PB/(sdPB*sdPB) + 1.0/(sdPC*sdPC),-1));
            meanQC = sdQC*sdQC*(b1PA*meanQA/(sdPA*sdPA) - b0PA*b1PA/(sdPA*sdPA) + b1PB*meanQB/(sdPB*sdPB) - b0PB*b1PB/(sdPB*sdPB) + meanPC/(sdPC*sdPC));

            if (Math.abs(sdQC + meanQC - oldvalue) < 0.001) {
                convergence = true;
            }
            oldvalue = sdQC + meanQC;
        }

        if (Main.VERBOSE) System.out.println("Mean and Sd of C: " + meanQC +", " + sdQC );

        Assert.assertEquals(postC.getMean(),meanQC,0.01);
        Assert.assertEquals(postC.getSd(),sdQC,0.01);
    }
 
Example #16
Source File: VMPNormalTest.java    From toolbox with Apache License 2.0 4 votes vote down vote up
/**
 * Checks VMP on a chain A -> B -> C with evidence B = 0.4.
 *
 * <p>A has mean 1 and variance 0.25; B given A has intercept 1, coefficient 1 and
 * variance 0.04; C given B has intercept 1, coefficient 1 and variance 0.25. The
 * posterior mean and standard deviation of A and C are compared (tolerance 0.01) with
 * values obtained by iterating closed-form update equations until the sum of the four
 * parameters changes by less than 0.001.
 *
 * @throws IOException declared by the test signature
 * @throws ClassNotFoundException declared by the test signature
 */
public static void test8() throws IOException, ClassNotFoundException{

        Variables variables = new Variables();
        Variable varA = variables.newGaussianVariable("A");
        Variable varB = variables.newGaussianVariable("B");
        Variable varC = variables.newGaussianVariable("C");

        DAG dag = new DAG(variables);

        dag.getParentSet(varB).addParent(varA);
        dag.getParentSet(varC).addParent(varB);

        BayesianNetwork bn = new BayesianNetwork(dag);

        Normal distA = bn.getConditionalDistribution(varA);
        ConditionalLinearGaussian distB = bn.getConditionalDistribution(varB);
        ConditionalLinearGaussian distC = bn.getConditionalDistribution(varC);

        distA.setMean(1);
        distA.setVariance(0.25);

        distB.setIntercept(1);
        distB.setCoeffForParent(varA, 1);
        distB.setVariance(0.04);

        distC.setIntercept(1);
        distC.setCoeffForParent(varB, 1);
        distC.setVariance(0.25);


        if (Main.VERBOSE) System.out.println(bn.toString());

        // Model parameters used by the closed-form reference computation.
        double meanPA =  distA.getMean();
        double sdPA =  distA.getSd();

        double b0PB =  distB.getIntercept();
        double b1PB = distB.getCoeffForParent(varA);
        double sdPB =  distB.getSd();

        double b0PC =  distC.getIntercept();
        double b1PC = distC.getCoeffForParent(varB);
        double sdPC =  distC.getSd();

        VMP vmp = new VMP();
        vmp.setTestELBO(true);
        vmp.setMaxIter(100);
        vmp.setThreshold(0.0001);
        vmp.setModel(bn);

        // B is observed, so only A's and C's initial Q distributions are needed.
        // Moment parameters are read as (E[X], E[X^2]).
        EF_Normal qADist = ((EF_Normal) vmp.getNodes().get(0).getQDist());
        EF_Normal qCDist = ((EF_Normal) vmp.getNodes().get(2).getQDist());

        double meanQA= qADist.getMomentParameters().get(0);
        double sdQA= Math.sqrt(qADist.getMomentParameters().get(1) - qADist.getMomentParameters().get(0) * qADist.getMomentParameters().get(0));

        double meanQC= qCDist.getMomentParameters().get(0);
        double sdQC= Math.sqrt(qCDist.getMomentParameters().get(1) - qCDist.getMomentParameters().get(0)*qCDist.getMomentParameters().get(0));

        // Observed value of B, used both as evidence and in the reference formulas.
        double meanQB= 0.4;

        HashMapAssignment assignment = new HashMapAssignment(1);
        assignment.setValue(varB, 0.4);

        vmp.setEvidence(assignment);
        vmp.runInference();

        Normal postA = vmp.getPosterior(varA);
        if (Main.VERBOSE) System.out.println("P(A) = " + postA.toString());
        Normal postC = ((Normal)vmp.getPosterior(varC));
        if (Main.VERBOSE) System.out.println("P(C) = " + postC.toString());

        boolean convergence = false;
        double oldvalue = 0;

        while(!convergence){

            sdQA = Math.sqrt(Math.pow(b1PB*b1PB/(sdPB*sdPB) + 1.0/(sdPA*sdPA),-1));
            meanQA = sdQA*sdQA*(b1PB*meanQB/(sdPB*sdPB) - b0PB*b1PB/(sdPB*sdPB) + meanPA/(sdPA*sdPA));

            // C has no children, so its sd update is just sdPC.
            sdQC = sdPC;
            meanQC = sdQC*sdQC*(b0PC/(sdPC*sdPC) + b1PC*meanQB/(sdPC*sdPC));

            if (Math.abs(sdQA + meanQA + sdQC + meanQC - oldvalue) < 0.001) {
                convergence = true;
            }
            oldvalue = sdQA + meanQA + sdQC + meanQC;
        }

        if (Main.VERBOSE) System.out.println("Mean and Sd of A: " + meanQA +", " + sdQA );
        if (Main.VERBOSE) System.out.println("Mean and Sd of C: " + meanQC +", " + sdQC );

        Assert.assertEquals(postA.getMean(),meanQA,0.01);
        Assert.assertEquals(postA.getSd(),sdQA,0.01);
        Assert.assertEquals(postC.getMean(),meanQC,0.01);
        Assert.assertEquals(postC.getSd(),sdQC,0.01);
    }