Putting our plan to the test

- Implement plan simplification
- Add unit tests for plan execution and simplification
This commit is contained in:
Jozufozu 2023-04-09 14:15:29 -07:00
parent fb11f29010
commit 1627874e33
7 changed files with 400 additions and 1 deletions

View file

@ -129,6 +129,7 @@ minecraft.runs.all {
// ^--------------------------------------------------------------------^
dependencies {
testImplementation 'org.junit.jupiter:junit-jupiter:5.8.1'
minecraft "net.minecraftforge:forge:${minecraft_version}-${forge_version}"
jarJar('org.joml:joml:1.10.5') {
@ -154,6 +155,10 @@ dependencies {
}
}
test {
useJUnitPlatform()
}
mixin {
add sourceSets.main, 'flywheel.refmap.json'
}

View file

@ -8,4 +8,19 @@ public record BarrierPlan(Plan first, Plan second) implements Plan {
public void execute(TaskExecutor taskExecutor, Runnable onCompletion) {
first.execute(taskExecutor, () -> second.execute(taskExecutor, onCompletion));
}
@Override
public Plan maybeSimplify() {
var first = this.first.maybeSimplify();
var second = this.second.maybeSimplify();
if (first == UnitPlan.INSTANCE) {
return second;
}
if (second == UnitPlan.INSTANCE) {
return first;
}
return new BarrierPlan(first, second);
}
}

View file

@ -1,11 +1,17 @@
package com.jozufozu.flywheel.lib.task;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor;
public record NestedPlan(List<Plan> parallelPlans) implements Plan {
public static NestedPlan of(Plan... plans) {
return new NestedPlan(List.of(plans));
}
@Override
public void execute(TaskExecutor taskExecutor, Runnable onCompletion) {
if (parallelPlans.isEmpty()) {
@ -26,4 +32,63 @@ public record NestedPlan(List<Plan> parallelPlans) implements Plan {
plan.execute(taskExecutor, wait::decrementAndEventuallyRun);
}
}
@Override
public Plan maybeSimplify() {
if (parallelPlans.isEmpty()) {
return UnitPlan.INSTANCE;
}
if (parallelPlans.size() == 1) {
return parallelPlans.get(0)
.maybeSimplify();
}
var simplifiedTasks = new ArrayList<Runnable>();
var simplifiedPlans = new ArrayList<Plan>();
var toVisit = new ArrayDeque<>(parallelPlans);
while (!toVisit.isEmpty()) {
var plan = toVisit.pop()
.maybeSimplify();
if (plan == UnitPlan.INSTANCE) {
continue;
}
if (plan instanceof SimplePlan simplePlan) {
// merge all simple plans into one
simplifiedTasks.addAll(simplePlan.parallelTasks());
} else if (plan instanceof NestedPlan nestedPlan) {
// inline and re-visit nested plans
toVisit.addAll(nestedPlan.parallelPlans());
} else {
// /shrug
simplifiedPlans.add(plan);
}
}
if (simplifiedTasks.isEmpty() && simplifiedPlans.isEmpty()) {
// everything got simplified away
return UnitPlan.INSTANCE;
}
if (simplifiedTasks.isEmpty()) {
// no simple plan to create
if (simplifiedPlans.size() == 1) {
// we only contained one complex plan, so we can just return that
return simplifiedPlans.get(0);
}
return new NestedPlan(simplifiedPlans);
}
if (simplifiedPlans.isEmpty()) {
// we only contained simple plans, so we can just return one
return new SimplePlan(simplifiedTasks);
}
// we have both simple and complex plans, so we need to create a nested plan
simplifiedPlans.add(new SimplePlan(simplifiedTasks));
return new NestedPlan(simplifiedPlans);
}
}

View file

@ -20,7 +20,7 @@ public class PlanUtil {
}
public static Plan of(Runnable... tasks) {
return new SimplePlan(List.of(tasks));
return SimplePlan.of(tasks);
}
public static Plan onMainThread(Runnable task) {

View file

@ -6,6 +6,10 @@ import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor;
public record SimplePlan(List<Runnable> parallelTasks) implements Plan {
public static Plan of(Runnable... tasks) {
return new SimplePlan(List.of(tasks));
}
@Override
public void execute(TaskExecutor taskExecutor, Runnable onCompletion) {
if (parallelTasks.isEmpty()) {
@ -21,4 +25,13 @@ public record SimplePlan(List<Runnable> parallelTasks) implements Plan {
});
}
}
@Override
public Plan maybeSimplify() {
if (parallelTasks.isEmpty()) {
return UnitPlan.INSTANCE;
}
return this;
}
}

View file

@ -0,0 +1,192 @@
package com.jozufozu.flywheel.lib.task;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.backend.task.ParallelTaskExecutor;
import it.unimi.dsi.fastutil.ints.IntArrayList;
class PlanExecutionTest {
protected static final ParallelTaskExecutor EXECUTOR = new ParallelTaskExecutor("PlanTest");
@BeforeAll
public static void setUp() {
EXECUTOR.startWorkers();
}
@AfterAll
public static void tearDown() {
EXECUTOR.stopWorkers();
}
@ParameterizedTest
@ValueSource(ints = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
void testSynchronizer(int countDown) {
var done = new AtomicBoolean(false);
var synchronizer = new Synchronizer(countDown, () -> done.set(true));
for (int i = 0; i < countDown - 1; i++) {
synchronizer.decrementAndEventuallyRun();
Assertions.assertFalse(done.get(), "Done early at " + i);
}
synchronizer.decrementAndEventuallyRun();
Assertions.assertTrue(done.get());
}
@ParameterizedTest
@ValueSource(ints = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
void simpleBarrierSequencing(int barriers) {
var sequence = new IntArrayList(barriers + 1);
var expected = new IntArrayList(barriers + 1);
var plan = SimplePlan.of(() -> sequence.add(1));
expected.add(1);
for (int i = 0; i < barriers; i++) {
final int sequenceNum = i + 2;
expected.add(sequenceNum);
plan = plan.then(SimplePlan.of(() -> sequence.add(sequenceNum)));
}
runAndWait(plan);
Assertions.assertEquals(expected, sequence);
}
@RepeatedTest(10)
void wideBarrierSequencing() {
var lock = new Object();
var sequence = new IntArrayList(8);
Runnable addOne = () -> {
synchronized (lock) {
sequence.add(1);
}
};
Runnable addTwo = () -> {
synchronized (lock) {
sequence.add(2);
}
};
var plan = SimplePlan.of(addOne, addOne, addOne, addOne)
.then(SimplePlan.of(addTwo, addTwo, addTwo, addTwo));
runAndWait(plan);
assertExpectedSequence(sequence, 1, 1, 1, 1, 2, 2, 2, 2);
}
@Test
void simpleNestedPlan() {
var sequence = new IntArrayList(2);
var plan = NestedPlan.of(SimplePlan.of(() -> sequence.add(1)));
runAndWait(plan);
assertExpectedSequence(sequence, 1);
}
@Test
void manyNestedPlans() {
var counter = new AtomicInteger(0);
var count4 = NestedPlan.of(SimplePlan.of(counter::incrementAndGet, counter::incrementAndGet), SimplePlan.of(counter::incrementAndGet, counter::incrementAndGet));
runAndWait(count4);
Assertions.assertEquals(4, counter.get());
counter.set(0);
var count8Barrier = NestedPlan.of(count4, count4);
runAndWait(count8Barrier);
Assertions.assertEquals(8, counter.get());
}
@Test
void unitPlan() {
var done = new AtomicBoolean(false);
UnitPlan.INSTANCE.execute(null, () -> done.set(true));
Assertions.assertTrue(done.get());
}
@Test
void emptyPlan() {
var done = new AtomicBoolean(false);
SimplePlan.of()
.execute(null, () -> done.set(true));
Assertions.assertTrue(done.get());
done.set(false);
NestedPlan.of()
.execute(null, () -> done.set(true));
Assertions.assertTrue(done.get());
}
@Test
void mainThreadPlan() {
var done = new AtomicBoolean(false);
var plan = new OnMainThreadPlan(() -> done.set(true));
plan.execute(EXECUTOR);
Assertions.assertFalse(done.get());
EXECUTOR.syncPoint();
Assertions.assertTrue(done.get());
}
private static void assertExpectedSequence(IntArrayList sequence, int... expected) {
Assertions.assertArrayEquals(expected, sequence.toIntArray());
}
public static void runAndWait(Plan plan) {
new TestBarrier(plan).runAndWait();
}
private static final class TestBarrier {
private final Plan plan;
private boolean done = false;
private TestBarrier(Plan plan) {
this.plan = plan;
}
public void runAndWait() {
plan.execute(EXECUTOR, this::doneWithPlan);
synchronized (this) {
// early exit in case the plan is already done for e.g. UnitPlan
if (done) {
return;
}
try {
wait();
} catch (InterruptedException ignored) {
// noop
}
}
}
public void doneWithPlan() {
synchronized (this) {
done = true;
notifyAll();
}
}
}
}

View file

@ -0,0 +1,109 @@
package com.jozufozu.flywheel.lib.task;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import com.jozufozu.flywheel.api.task.Plan;
public class PlanSimplificationTest {
public static final Runnable NOOP = () -> {
};
public static final Plan SIMPLE = SimplePlan.of(NOOP);
@Test
void emptyPlans() {
var empty = NestedPlan.of();
Assertions.assertEquals(empty.maybeSimplify(), UnitPlan.INSTANCE);
var simpleEmpty = SimplePlan.of();
Assertions.assertEquals(simpleEmpty.maybeSimplify(), UnitPlan.INSTANCE);
}
@Test
void nestedSimplePlans() {
var twoSimple = NestedPlan.of(SimplePlan.of(NOOP, NOOP, NOOP), SIMPLE);
Assertions.assertEquals(twoSimple.maybeSimplify(), SimplePlan.of(NOOP, NOOP, NOOP, NOOP));
var threeSimple = NestedPlan.of(SIMPLE, SIMPLE, SIMPLE);
Assertions.assertEquals(threeSimple.maybeSimplify(), SimplePlan.of(NOOP, NOOP, NOOP));
}
@Test
void oneNestedPlan() {
var oneSimple = NestedPlan.of(SIMPLE);
Assertions.assertEquals(oneSimple.maybeSimplify(), SIMPLE);
var mainThreadNoop = new OnMainThreadPlan(NOOP);
var oneMainThread = NestedPlan.of(mainThreadNoop);
Assertions.assertEquals(oneMainThread.maybeSimplify(), mainThreadNoop);
var barrier = new BarrierPlan(SIMPLE, SIMPLE);
var oneBarrier = NestedPlan.of(barrier);
Assertions.assertEquals(oneBarrier.maybeSimplify(), barrier);
}
@Test
void nestedNestedPlan() {
var outer = NestedPlan.of(SIMPLE);
var outermost = NestedPlan.of(outer);
Assertions.assertEquals(outermost.maybeSimplify(), SIMPLE);
}
@Test
void nestedUnitPlan() {
var onlyUnit = NestedPlan.of(UnitPlan.INSTANCE, UnitPlan.INSTANCE, UnitPlan.INSTANCE);
Assertions.assertEquals(onlyUnit.maybeSimplify(), UnitPlan.INSTANCE);
var unitAndSimple = NestedPlan.of(UnitPlan.INSTANCE, UnitPlan.INSTANCE, SIMPLE);
Assertions.assertEquals(unitAndSimple.maybeSimplify(), SIMPLE);
}
@Test
void complexNesting() {
var mainThreadNoop = new OnMainThreadPlan(NOOP);
var nested = NestedPlan.of(mainThreadNoop, SIMPLE);
Assertions.assertEquals(nested.maybeSimplify(), nested); // cannot simplify
var barrier = new BarrierPlan(SIMPLE, SIMPLE);
var complex = NestedPlan.of(barrier, nested);
Assertions.assertEquals(complex.maybeSimplify(), NestedPlan.of(barrier, mainThreadNoop, SIMPLE));
}
@Test
void nestedNoSimple() {
var mainThreadNoop = new OnMainThreadPlan(NOOP);
var barrier = new BarrierPlan(SIMPLE, SIMPLE);
var oneMainThread = NestedPlan.of(mainThreadNoop, NestedPlan.of(mainThreadNoop, barrier, barrier));
Assertions.assertEquals(oneMainThread.maybeSimplify(), NestedPlan.of(mainThreadNoop, mainThreadNoop, barrier, barrier));
}
@Test
void manyNestedButJustOneAfterSimplification() {
var barrier = new BarrierPlan(SIMPLE, SIMPLE);
var oneMainThread = NestedPlan.of(barrier, NestedPlan.of(UnitPlan.INSTANCE, UnitPlan.INSTANCE));
Assertions.assertEquals(oneMainThread.maybeSimplify(), barrier);
}
@Test
void barrierPlan() {
var doubleUnit = new BarrierPlan(UnitPlan.INSTANCE, UnitPlan.INSTANCE);
Assertions.assertEquals(doubleUnit.maybeSimplify(), UnitPlan.INSTANCE);
var simpleThenUnit = new BarrierPlan(SIMPLE, UnitPlan.INSTANCE);
Assertions.assertEquals(simpleThenUnit.maybeSimplify(), SIMPLE);
var unitThenSimple = new BarrierPlan(UnitPlan.INSTANCE, SIMPLE);
Assertions.assertEquals(unitThenSimple.maybeSimplify(), SIMPLE);
var simpleThenSimple = new BarrierPlan(SIMPLE, SIMPLE);
Assertions.assertEquals(simpleThenSimple.maybeSimplify(), new BarrierPlan(SIMPLE, SIMPLE));
}
}