package io.trino.operator.window.matcher;

import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.operator.window.matcher.Instruction;
import io.trino.operator.window.pattern.LabelEvaluator;
import io.trino.operator.window.pattern.MatchAggregation;
import io.trino.operator.window.pattern.PhysicalValueAccessor;
import io.trino.sql.planner.LocalExecutionPlanner;
import java.util.List;
import org.openjdk.jol.info.ClassLayout;

/* loaded from: input_file:io/trino/operator/window/matcher/Matcher.class */
/**
 * Executes a compiled row-pattern {@link Program} against an input sequence using a
 * set of lightweight "threads", each of which is an instruction pointer plus its own
 * captures and aggregation state.
 *
 * <p>The input is consumed one position at a time. At every position each live thread
 * either consumes a label ({@code MATCH_LABEL}), reports a match ({@code DONE}), or is
 * killed. Threads that survive are advanced through non-consuming instructions by
 * {@link #advanceAndSchedule} and collected for the next position.
 */
public class Matcher {
    private final Program program;
    private final ThreadEquivalence threadEquivalence;
    private final List<MatchAggregation.MatchAggregationInstantiator> aggregations;

    /**
     * Mutable per-run state: thread bookkeeping, captures and aggregations.
     * A fresh instance is created by every call to {@link Matcher#run}.
     */
    public static class Runtime {
        private static final int INSTANCE_SIZE = ClassLayout.parseClass(Runtime.class).instanceSize();
        // for each instruction pointer, the ids of threads that reached it at the current input position
        private final IntMultimap threadsAtInstructions;
        // thread id -> instruction pointer at which the thread is parked
        private final IntList threads;
        // ids of killed threads, available for reuse by newThread()
        private final IntStack freeThreadIds;
        // next never-used thread id
        private int newThreadId;
        private final int inputLength;
        private final boolean matchingAtPartitionStart;
        private final Captures captures;
        private final MatchAggregations aggregations;

        /**
         * @param program the program to execute; its size drives the sizing of the internal structures
         * @param inputLength number of input positions available for matching
         * @param matchingAtPartitionStart whether {@code MATCH_START} may succeed at position 0
         * @param aggregationInstantiators instantiators handed to {@link MatchAggregations}
         * @param aggregationsMemoryContext memory context handed to {@link MatchAggregations}
         */
        public Runtime(Program program, int inputLength, boolean matchingAtPartitionStart, List<MatchAggregation.MatchAggregationInstantiator> aggregationInstantiators, AggregatedMemoryContext aggregationsMemoryContext) {
            // sizing argument for the per-thread structures (presumably an initial capacity -- confirm against IntList/IntStack)
            int initialCapacity = 2 * program.size();
            this.threads = new IntList(initialCapacity);
            this.freeThreadIds = new IntStack(initialCapacity);
            this.captures = new Captures(initialCapacity, program.getMinSlotCount(), program.getMinLabelCount());
            this.inputLength = inputLength;
            this.matchingAtPartitionStart = matchingAtPartitionStart;
            this.aggregations = new MatchAggregations(initialCapacity, aggregationInstantiators, aggregationsMemoryContext);
            this.threadsAtInstructions = new IntMultimap(program.size(), program.size());
        }

        /**
         * Allocates a new thread id and copies the captures and aggregations of
         * {@code parentThreadId} into it.
         *
         * @return the id of the forked thread
         */
        private int forkThread(int parentThreadId) {
            int childThreadId = newThread();
            this.captures.copy(parentThreadId, childThreadId);
            this.aggregations.copy(parentThreadId, childThreadId);
            return childThreadId;
        }

        /**
         * Returns a thread id, reusing a previously released id when one is available
         * and otherwise handing out the next sequential id.
         */
        private int newThread() {
            if (this.freeThreadIds.size() > 0) {
                return this.freeThreadIds.pop();
            }
            int threadId = this.newThreadId;
            this.newThreadId = threadId + 1;
            return threadId;
        }

        /**
         * Marks {@code threadId} as free for reuse and releases its captures and aggregations.
         */
        private void killThread(int threadId) {
            this.freeThreadIds.push(threadId);
            this.captures.release(threadId);
            this.aggregations.release(threadId);
        }

        /**
         * @return the instance size plus the sizes reported by all owned structures
         */
        private long getSizeInBytes() {
            return INSTANCE_SIZE
                    + this.threadsAtInstructions.getSizeInBytes()
                    + this.threads.getSizeInBytes()
                    + this.freeThreadIds.getSizeInBytes()
                    + this.captures.getSizeInBytes()
                    + this.aggregations.getSizeInBytes();
        }
    }

    /**
     * @param program the compiled pattern program
     * @param accessors passed to {@link ThreadEquivalence}
     * @param labelDependencies passed to {@link ThreadEquivalence}
     * @param aggregations instantiators used to create per-run aggregation state
     */
    public Matcher(Program program, List<List<PhysicalValueAccessor>> accessors, List<LocalExecutionPlanner.MatchAggregationLabelDependency> labelDependencies, List<MatchAggregation.MatchAggregationInstantiator> aggregations) {
        this.program = program;
        this.threadEquivalence = new ThreadEquivalence(program, accessors, labelDependencies);
        this.aggregations = aggregations;
    }

    /**
     * Runs the program over the input exposed by {@code labelEvaluator}.
     *
     * @param labelEvaluator supplies the input length, the partition-start flag and label evaluation
     * @param memoryContext updated after every input position with the current size of the matcher state
     * @param aggregationsMemoryContext memory context passed to the per-run aggregation state
     * @return the match found, or {@code MatchResult.NO_MATCH} if no thread reached {@code DONE}
     * @throws UnsupportedOperationException if a parked thread points at an instruction other than
     *         {@code MATCH_LABEL} or {@code DONE}
     */
    public MatchResult run(LabelEvaluator labelEvaluator, LocalMemoryContext memoryContext, AggregatedMemoryContext aggregationsMemoryContext) {
        IntList current = new IntList(this.program.size());
        IntList next = new IntList(this.program.size());
        int inputLength = labelEvaluator.getInputLength();
        Runtime runtime = new Runtime(this.program, inputLength, labelEvaluator.isMatchingAtPartitionStart(), this.aggregations, aggregationsMemoryContext);

        // seed the run with a single thread starting at instruction 0, input position 0
        advanceAndSchedule(current, runtime.newThread(), 0, 0, runtime);

        MatchResult result = MatchResult.NO_MATCH;
        // stop early when no live threads remain
        for (int index = 0; index < inputLength && current.size() != 0; index++) {
            boolean matched = false;
            // thread-equivalence bookkeeping is per input position
            runtime.threadsAtInstructions.clear();

            for (int i = 0; i < current.size(); i++) {
                int threadId = current.get(i);
                int pointer = runtime.threads.get(threadId);
                Instruction instruction = this.program.at(pointer);
                switch (instruction.type()) {
                    case MATCH_LABEL:
                        // the label is saved before evaluation; if evaluation fails the thread is
                        // killed together with its captures, so the tentative label is discarded
                        runtime.captures.saveLabel(threadId, ((MatchLabel) instruction).getLabel());
                        if (labelEvaluator.evaluateLabel(runtime.captures.getLabels(threadId), runtime.aggregations.get(threadId))) {
                            advanceAndSchedule(next, threadId, pointer + 1, index + 1, runtime);
                        } else {
                            runtime.killThread(threadId);
                        }
                        break;
                    case DONE:
                        matched = true;
                        result = new MatchResult(true, runtime.captures.getLabels(threadId), runtime.captures.getCaptures(threadId));
                        runtime.killThread(threadId);
                        break;
                    default:
                        throw new UnsupportedOperationException("not yet implemented");
                }
                if (matched) {
                    // a match was found: the threads after it in the list are discarded unprocessed
                    // (presumably they are on less preferred paths -- confirm against Program ordering)
                    for (int j = i + 1; j < current.size(); j++) {
                        runtime.killThread(current.get(j));
                    }
                    break;
                }
            }

            memoryContext.setBytes(runtime.getSizeInBytes() + current.getSizeInBytes() + next.getSizeInBytes());

            // swap the lists: the advanced threads become the input of the next position
            IntList recycled = current;
            recycled.clear();
            current = next;
            next = recycled;
        }

        // after the input loop, a surviving thread may already be parked on DONE;
        // the first such thread in the list determines the result
        for (int i = 0; i < current.size(); i++) {
            int threadId = current.get(i);
            if (this.program.at(runtime.threads.get(threadId)).type() == Instruction.Type.DONE) {
                result = new MatchResult(true, runtime.captures.getLabels(threadId), runtime.captures.getCaptures(threadId));
                break;
            }
        }
        return result;
    }

    /**
     * Moves {@code threadId} to instruction {@code pointer} and keeps following instructions
     * that do not consume input ({@code MATCH_START}, {@code MATCH_END}, {@code JUMP},
     * {@code SPLIT}, {@code SAVE}) until the thread is killed or reaches any other
     * instruction, at which point it is parked there and appended to {@code next}.
     *
     * <p>If a thread already recorded at {@code pointer} for the current position is
     * reported equivalent by {@link ThreadEquivalence}, the incoming thread is killed instead.
     *
     * @param next list receiving the ids of threads parked by this call
     * @param threadId the thread being advanced
     * @param pointer the instruction the thread is moving to
     * @param inputIndex the current input position
     * @param runtime the per-run state
     */
    private void advanceAndSchedule(IntList next, int threadId, int pointer, int inputIndex, Runtime runtime) {
        // drop the thread if an equivalent one has already reached this instruction
        ArrayView threadsAtInstruction = runtime.threadsAtInstructions.getArrayView(pointer);
        for (int i = 0; i < threadsAtInstruction.length(); i++) {
            int otherThreadId = threadsAtInstruction.get(i);
            if (this.threadEquivalence.equivalent(
                    otherThreadId,
                    runtime.captures.getLabels(otherThreadId),
                    runtime.aggregations.get(otherThreadId),
                    threadId,
                    runtime.captures.getLabels(threadId),
                    runtime.aggregations.get(threadId),
                    pointer)) {
                runtime.killThread(threadId);
                return;
            }
        }
        runtime.threadsAtInstructions.add(pointer, threadId);

        Instruction instruction = this.program.at(pointer);
        switch (instruction.type()) {
            case MATCH_START:
                // succeeds only at input position 0 of a run that starts at the partition start
                if (inputIndex == 0 && runtime.matchingAtPartitionStart) {
                    advanceAndSchedule(next, threadId, pointer + 1, inputIndex, runtime);
                } else {
                    runtime.killThread(threadId);
                }
                return;
            case MATCH_END:
                // succeeds only once the whole input has been consumed
                if (inputIndex == runtime.inputLength) {
                    advanceAndSchedule(next, threadId, pointer + 1, inputIndex, runtime);
                } else {
                    runtime.killThread(threadId);
                }
                return;
            case JUMP:
                advanceAndSchedule(next, threadId, ((Jump) instruction).getTarget(), inputIndex, runtime);
                return;
            case SPLIT:
                // fork before advancing so the child starts from the parent's unmodified state;
                // the first branch is scheduled ahead of the second
                int forkedThreadId = runtime.forkThread(threadId);
                advanceAndSchedule(next, threadId, ((Split) instruction).getFirst(), inputIndex, runtime);
                advanceAndSchedule(next, forkedThreadId, ((Split) instruction).getSecond(), inputIndex, runtime);
                return;
            case SAVE:
                runtime.captures.save(threadId, inputIndex);
                advanceAndSchedule(next, threadId, pointer + 1, inputIndex, runtime);
                return;
            default:
                // any other instruction: park the thread here until run() processes it
                runtime.threads.set(threadId, pointer);
                next.add(threadId);
                return;
        }
    }
}
