From 1627874e333a8654989738c04956ca1a7fb84705 Mon Sep 17 00:00:00 2001 From: Jozufozu Date: Sun, 9 Apr 2023 14:15:29 -0700 Subject: [PATCH] Putting our plan to the test - Implement plan simplification - Add unit tests for plan execution and simplification --- build.gradle | 5 + .../flywheel/lib/task/BarrierPlan.java | 15 ++ .../flywheel/lib/task/NestedPlan.java | 65 ++++++ .../jozufozu/flywheel/lib/task/PlanUtil.java | 2 +- .../flywheel/lib/task/SimplePlan.java | 13 ++ .../flywheel/lib/task/PlanExecutionTest.java | 192 ++++++++++++++++++ .../lib/task/PlanSimplificationTest.java | 109 ++++++++++ 7 files changed, 400 insertions(+), 1 deletion(-) create mode 100644 src/test/java/com/jozufozu/flywheel/lib/task/PlanExecutionTest.java create mode 100644 src/test/java/com/jozufozu/flywheel/lib/task/PlanSimplificationTest.java diff --git a/build.gradle b/build.gradle index 60cd5e602..2d388dd93 100644 --- a/build.gradle +++ b/build.gradle @@ -129,6 +129,7 @@ minecraft.runs.all { // ^--------------------------------------------------------------------^ dependencies { + testImplementation 'org.junit.jupiter:junit-jupiter:5.8.1' minecraft "net.minecraftforge:forge:${minecraft_version}-${forge_version}" jarJar('org.joml:joml:1.10.5') { @@ -154,6 +155,10 @@ dependencies { } } +test { + useJUnitPlatform() +} + mixin { add sourceSets.main, 'flywheel.refmap.json' } diff --git a/src/main/java/com/jozufozu/flywheel/lib/task/BarrierPlan.java b/src/main/java/com/jozufozu/flywheel/lib/task/BarrierPlan.java index 8b28e9361..c304a0867 100644 --- a/src/main/java/com/jozufozu/flywheel/lib/task/BarrierPlan.java +++ b/src/main/java/com/jozufozu/flywheel/lib/task/BarrierPlan.java @@ -8,4 +8,19 @@ public record BarrierPlan(Plan first, Plan second) implements Plan { public void execute(TaskExecutor taskExecutor, Runnable onCompletion) { first.execute(taskExecutor, () -> second.execute(taskExecutor, onCompletion)); } + + @Override + public Plan maybeSimplify() { + var first = this.first.maybeSimplify(); + var second = this.second.maybeSimplify(); + + if (first == UnitPlan.INSTANCE) { + return second; + } + if (second == UnitPlan.INSTANCE) { + return first; + } + + return new BarrierPlan(first, second); + } } diff --git a/src/main/java/com/jozufozu/flywheel/lib/task/NestedPlan.java b/src/main/java/com/jozufozu/flywheel/lib/task/NestedPlan.java index c440b35df..cec6812e9 100644 --- a/src/main/java/com/jozufozu/flywheel/lib/task/NestedPlan.java +++ b/src/main/java/com/jozufozu/flywheel/lib/task/NestedPlan.java @@ -1,11 +1,17 @@ package com.jozufozu.flywheel.lib.task; +import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.List; import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.task.TaskExecutor; public record NestedPlan(List parallelPlans) implements Plan { + public static NestedPlan of(Plan... plans) { + return new NestedPlan(List.of(plans)); + } + @Override public void execute(TaskExecutor taskExecutor, Runnable onCompletion) { if (parallelPlans.isEmpty()) { @@ -26,4 +32,63 @@ public record NestedPlan(List parallelPlans) implements Plan { plan.execute(taskExecutor, wait::decrementAndEventuallyRun); } } + + @Override + public Plan maybeSimplify() { + if (parallelPlans.isEmpty()) { + return UnitPlan.INSTANCE; + } + + if (parallelPlans.size() == 1) { + return parallelPlans.get(0) + .maybeSimplify(); + } + + var simplifiedTasks = new ArrayList(); + var simplifiedPlans = new ArrayList(); + + var toVisit = new ArrayDeque<>(parallelPlans); + while (!toVisit.isEmpty()) { + var plan = toVisit.pop() + .maybeSimplify(); + + if (plan == UnitPlan.INSTANCE) { + continue; + } + + if (plan instanceof SimplePlan simplePlan) { + // merge all simple plans into one + simplifiedTasks.addAll(simplePlan.parallelTasks()); + } else if (plan instanceof NestedPlan nestedPlan) { + // inline and re-visit nested plans + toVisit.addAll(nestedPlan.parallelPlans()); + } else { + // /shrug + simplifiedPlans.add(plan); + } + } + + if (simplifiedTasks.isEmpty() && simplifiedPlans.isEmpty()) { + // everything got simplified away + return UnitPlan.INSTANCE; + } + + if (simplifiedTasks.isEmpty()) { + // no simple plan to create + if (simplifiedPlans.size() == 1) { + // we only contained one complex plan, so we can just return that + return simplifiedPlans.get(0); + } + return new NestedPlan(simplifiedPlans); + } + + if (simplifiedPlans.isEmpty()) { + // we only contained simple plans, so we can just return one + return new SimplePlan(simplifiedTasks); + } + + // we have both simple and complex plans, so we need to create a nested plan + simplifiedPlans.add(new SimplePlan(simplifiedTasks)); + return new NestedPlan(simplifiedPlans); + } } diff --git a/src/main/java/com/jozufozu/flywheel/lib/task/PlanUtil.java b/src/main/java/com/jozufozu/flywheel/lib/task/PlanUtil.java index 4d19f2e80..6cdbdf196 100644 --- a/src/main/java/com/jozufozu/flywheel/lib/task/PlanUtil.java +++ b/src/main/java/com/jozufozu/flywheel/lib/task/PlanUtil.java @@ -20,7 +20,7 @@ public class PlanUtil { } public static Plan of(Runnable... tasks) { - return new SimplePlan(List.of(tasks)); + return SimplePlan.of(tasks); } public static Plan onMainThread(Runnable task) { diff --git a/src/main/java/com/jozufozu/flywheel/lib/task/SimplePlan.java b/src/main/java/com/jozufozu/flywheel/lib/task/SimplePlan.java index a20ec1192..45734fbbc 100644 --- a/src/main/java/com/jozufozu/flywheel/lib/task/SimplePlan.java +++ b/src/main/java/com/jozufozu/flywheel/lib/task/SimplePlan.java @@ -6,6 +6,10 @@ import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.task.TaskExecutor; public record SimplePlan(List parallelTasks) implements Plan { + public static Plan of(Runnable... tasks) { + return new SimplePlan(List.of(tasks)); + } + @Override public void execute(TaskExecutor taskExecutor, Runnable onCompletion) { if (parallelTasks.isEmpty()) { @@ -21,4 +25,13 @@ public record SimplePlan(List parallelTasks) implements Plan { }); } } + + @Override + public Plan maybeSimplify() { + if (parallelTasks.isEmpty()) { + return UnitPlan.INSTANCE; + } + + return this; + } } diff --git a/src/test/java/com/jozufozu/flywheel/lib/task/PlanExecutionTest.java b/src/test/java/com/jozufozu/flywheel/lib/task/PlanExecutionTest.java new file mode 100644 index 000000000..6f572dee7 --- /dev/null +++ b/src/test/java/com/jozufozu/flywheel/lib/task/PlanExecutionTest.java @@ -0,0 +1,192 @@ +package com.jozufozu.flywheel.lib.task; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import com.jozufozu.flywheel.api.task.Plan; +import com.jozufozu.flywheel.backend.task.ParallelTaskExecutor; + +import it.unimi.dsi.fastutil.ints.IntArrayList; + +class PlanExecutionTest { + + protected static final ParallelTaskExecutor EXECUTOR = new ParallelTaskExecutor("PlanTest"); + + @BeforeAll + public static void setUp() { + EXECUTOR.startWorkers(); + } + + @AfterAll + public static void tearDown() { + EXECUTOR.stopWorkers(); + } + + @ParameterizedTest + @ValueSource(ints = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + void testSynchronizer(int countDown) { + var done = new AtomicBoolean(false); + var synchronizer = new Synchronizer(countDown, () -> done.set(true)); + + for (int i = 0; i < countDown - 1; i++) { + synchronizer.decrementAndEventuallyRun(); + Assertions.assertFalse(done.get(), "Done early at " + i); + } + + synchronizer.decrementAndEventuallyRun(); + Assertions.assertTrue(done.get()); + } + + @ParameterizedTest + @ValueSource(ints = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + void simpleBarrierSequencing(int barriers) { + var sequence = new IntArrayList(barriers + 1); + var expected = new IntArrayList(barriers + 1); + + var plan = SimplePlan.of(() -> sequence.add(1)); + expected.add(1); + + for (int i = 0; i < barriers; i++) { + final int sequenceNum = i + 2; + expected.add(sequenceNum); + plan = plan.then(SimplePlan.of(() -> sequence.add(sequenceNum))); + } + + runAndWait(plan); + + Assertions.assertEquals(expected, sequence); + } + + @RepeatedTest(10) + void wideBarrierSequencing() { + var lock = new Object(); + var sequence = new IntArrayList(8); + + Runnable addOne = () -> { + synchronized (lock) { + sequence.add(1); + } + }; + Runnable addTwo = () -> { + synchronized (lock) { + sequence.add(2); + } + }; + + var plan = SimplePlan.of(addOne, addOne, addOne, addOne) + .then(SimplePlan.of(addTwo, addTwo, addTwo, addTwo)); + + runAndWait(plan); + + assertExpectedSequence(sequence, 1, 1, 1, 1, 2, 2, 2, 2); + } + + @Test + void simpleNestedPlan() { + var sequence = new IntArrayList(2); + var plan = NestedPlan.of(SimplePlan.of(() -> sequence.add(1))); + runAndWait(plan); + assertExpectedSequence(sequence, 1); + } + + @Test + void manyNestedPlans() { + var counter = new AtomicInteger(0); + var count4 = NestedPlan.of(SimplePlan.of(counter::incrementAndGet, counter::incrementAndGet), SimplePlan.of(counter::incrementAndGet, counter::incrementAndGet)); + + runAndWait(count4); + Assertions.assertEquals(4, counter.get()); + + counter.set(0); + + var count8Barrier = NestedPlan.of(count4, count4); + runAndWait(count8Barrier); + Assertions.assertEquals(8, counter.get()); + } + + @Test + void unitPlan() { + var done = new AtomicBoolean(false); + + UnitPlan.INSTANCE.execute(null, () -> done.set(true)); + + Assertions.assertTrue(done.get()); + } + + @Test + void emptyPlan() { + var done = new AtomicBoolean(false); + + SimplePlan.of() + .execute(null, () -> done.set(true)); + Assertions.assertTrue(done.get()); + + done.set(false); + NestedPlan.of() + .execute(null, () -> done.set(true)); + Assertions.assertTrue(done.get()); + } + + @Test + void mainThreadPlan() { + var done = new AtomicBoolean(false); + var plan = new OnMainThreadPlan(() -> done.set(true)); + + plan.execute(EXECUTOR); + + Assertions.assertFalse(done.get()); + + EXECUTOR.syncPoint(); + + Assertions.assertTrue(done.get()); + } + + private static void assertExpectedSequence(IntArrayList sequence, int... expected) { + Assertions.assertArrayEquals(expected, sequence.toIntArray()); + } + + public static void runAndWait(Plan plan) { + new TestBarrier(plan).runAndWait(); + } + + private static final class TestBarrier { + private final Plan plan; + private boolean done = false; + + private TestBarrier(Plan plan) { + this.plan = plan; + } + + public void runAndWait() { + plan.execute(EXECUTOR, this::doneWithPlan); + + synchronized (this) { + // early exit in case the plan is already done for e.g. UnitPlan + if (done) { + return; + } + + try { + wait(); + } catch (InterruptedException ignored) { + // noop + } + } + } + + public void doneWithPlan() { + synchronized (this) { + done = true; + notifyAll(); + } + } + } +} diff --git a/src/test/java/com/jozufozu/flywheel/lib/task/PlanSimplificationTest.java b/src/test/java/com/jozufozu/flywheel/lib/task/PlanSimplificationTest.java new file mode 100644 index 000000000..ec051599d --- /dev/null +++ b/src/test/java/com/jozufozu/flywheel/lib/task/PlanSimplificationTest.java @@ -0,0 +1,109 @@ +package com.jozufozu.flywheel.lib.task; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import com.jozufozu.flywheel.api.task.Plan; + +public class PlanSimplificationTest { + + public static final Runnable NOOP = () -> { + }; + public static final Plan SIMPLE = SimplePlan.of(NOOP); + + @Test + void emptyPlans() { + var empty = NestedPlan.of(); + Assertions.assertEquals(empty.maybeSimplify(), UnitPlan.INSTANCE); + + var simpleEmpty = SimplePlan.of(); + Assertions.assertEquals(simpleEmpty.maybeSimplify(), UnitPlan.INSTANCE); + } + + @Test + void nestedSimplePlans() { + var twoSimple = NestedPlan.of(SimplePlan.of(NOOP, NOOP, NOOP), SIMPLE); + Assertions.assertEquals(twoSimple.maybeSimplify(), SimplePlan.of(NOOP, NOOP, NOOP, NOOP)); + + var threeSimple = NestedPlan.of(SIMPLE, SIMPLE, SIMPLE); + Assertions.assertEquals(threeSimple.maybeSimplify(), SimplePlan.of(NOOP, NOOP, NOOP)); + } + + @Test + void oneNestedPlan() { + var oneSimple = NestedPlan.of(SIMPLE); + + Assertions.assertEquals(oneSimple.maybeSimplify(), SIMPLE); + + var mainThreadNoop = new OnMainThreadPlan(NOOP); + var oneMainThread = NestedPlan.of(mainThreadNoop); + + Assertions.assertEquals(oneMainThread.maybeSimplify(), mainThreadNoop); + + var barrier = new BarrierPlan(SIMPLE, SIMPLE); + var oneBarrier = NestedPlan.of(barrier); + + Assertions.assertEquals(oneBarrier.maybeSimplify(), barrier); + } + + @Test + void nestedNestedPlan() { + var outer = NestedPlan.of(SIMPLE); + var outermost = NestedPlan.of(outer); + + Assertions.assertEquals(outermost.maybeSimplify(), SIMPLE); + } + + @Test + void nestedUnitPlan() { + var onlyUnit = NestedPlan.of(UnitPlan.INSTANCE, UnitPlan.INSTANCE, UnitPlan.INSTANCE); + Assertions.assertEquals(onlyUnit.maybeSimplify(), UnitPlan.INSTANCE); + + var unitAndSimple = NestedPlan.of(UnitPlan.INSTANCE, UnitPlan.INSTANCE, SIMPLE); + Assertions.assertEquals(unitAndSimple.maybeSimplify(), SIMPLE); + } + + @Test + void complexNesting() { + var mainThreadNoop = new OnMainThreadPlan(NOOP); + + var nested = NestedPlan.of(mainThreadNoop, SIMPLE); + Assertions.assertEquals(nested.maybeSimplify(), nested); // cannot simplify + + var barrier = new BarrierPlan(SIMPLE, SIMPLE); + var complex = NestedPlan.of(barrier, nested); + Assertions.assertEquals(complex.maybeSimplify(), NestedPlan.of(barrier, mainThreadNoop, SIMPLE)); + } + + @Test + void nestedNoSimple() { + var mainThreadNoop = new OnMainThreadPlan(NOOP); + var barrier = new BarrierPlan(SIMPLE, SIMPLE); + var oneMainThread = NestedPlan.of(mainThreadNoop, NestedPlan.of(mainThreadNoop, barrier, barrier)); + + Assertions.assertEquals(oneMainThread.maybeSimplify(), NestedPlan.of(mainThreadNoop, mainThreadNoop, barrier, barrier)); + } + + @Test + void manyNestedButJustOneAfterSimplification() { + var barrier = new BarrierPlan(SIMPLE, SIMPLE); + var oneMainThread = NestedPlan.of(barrier, NestedPlan.of(UnitPlan.INSTANCE, UnitPlan.INSTANCE)); + + Assertions.assertEquals(oneMainThread.maybeSimplify(), barrier); + } + + @Test + void barrierPlan() { + var doubleUnit = new BarrierPlan(UnitPlan.INSTANCE, UnitPlan.INSTANCE); + Assertions.assertEquals(doubleUnit.maybeSimplify(), UnitPlan.INSTANCE); + + var simpleThenUnit = new BarrierPlan(SIMPLE, UnitPlan.INSTANCE); + Assertions.assertEquals(simpleThenUnit.maybeSimplify(), SIMPLE); + + var unitThenSimple = new BarrierPlan(UnitPlan.INSTANCE, SIMPLE); + Assertions.assertEquals(unitThenSimple.maybeSimplify(), SIMPLE); + + var simpleThenSimple = new BarrierPlan(SIMPLE, SIMPLE); + Assertions.assertEquals(simpleThenSimple.maybeSimplify(), new BarrierPlan(SIMPLE, SIMPLE)); + } +}