package io.trino.execution;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.log.Logger;
import io.airlift.stats.Distribution;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.Session;
import io.trino.execution.StateMachine;
import io.trino.execution.scheduler.SplitSchedulerStats;
import io.trino.operator.OperatorStats;
import io.trino.operator.PipelineStats;
import io.trino.operator.TaskStats;
import io.trino.spi.eventlistener.StageGcStatistics;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.util.Failures;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.concurrent.ThreadSafe;
import org.joda.time.DateTime;

/**
 * Tracks the lifecycle state of a single stage and aggregates statistics reported by its tasks.
 *
 * <p>The stage state lives in a {@link StateMachine} of {@link StageState}. Once the stage is done
 * and {@link #setAllTasksFinal} has stored a complete {@link StageInfo}, that frozen snapshot is
 * returned by {@link #getStageInfo} and used by {@link #getBasicStageStats} instead of recomputing
 * the statistics from task infos.
 */
@ThreadSafe
public class StageStateMachine {
    private static final Logger log = Logger.get(StageStateMachine.class);

    private final StageId stageId;
    private final PlanFragment fragment;
    private final Session session;
    private final Map<PlanNodeId, TableInfo> tables;
    private final SplitSchedulerStats scheduledStats;

    private final StateMachine<StageState> stageState;
    // Empty until setAllTasksFinal stores the final snapshot; it is set at most once.
    private final StateMachine<Optional<StageInfo>> finalStageInfo;
    // First recorded failure; only reported in StageInfo when the stage ends in FAILED.
    private final AtomicReference<ExecutionFailureInfo> failureCause = new AtomicReference<>();
    // Time of the first transitionToScheduled() call.
    private final AtomicReference<DateTime> schedulingComplete = new AtomicReference<>();
    // getSplit latencies in nanoseconds, fed by recordGetSplitTime.
    private final Distribution getSplitDistribution = new Distribution();

    // Memory counters in bytes, maintained incrementally by updateMemoryUsage.
    private final AtomicLong peakUserMemory = new AtomicLong();
    private final AtomicLong peakRevocableMemory = new AtomicLong();
    private final AtomicLong currentUserMemory = new AtomicLong();
    private final AtomicLong currentRevocableMemory = new AtomicLong();
    private final AtomicLong currentTotalMemory = new AtomicLong();

    /**
     * Creates a state machine for one stage, starting in {@link StageState#PLANNED}.
     *
     * @param stageId id of the stage; must not be null
     * @param session query session; must not be null
     * @param fragment plan fragment executed by this stage; must not be null
     * @param tables table metadata keyed by plan node id; copied defensively, must not be null
     * @param executor executor on which state-change listeners are notified
     * @param schedulerStats scheduler stats that receive getSplit timings; must not be null
     * @throws NullPointerException if any checked argument is null
     */
    public StageStateMachine(StageId stageId, Session session, PlanFragment fragment, Map<PlanNodeId, TableInfo> tables, Executor executor, SplitSchedulerStats schedulerStats) {
        this.stageId = Objects.requireNonNull(stageId, "stageId is null");
        this.session = Objects.requireNonNull(session, "session is null");
        this.fragment = Objects.requireNonNull(fragment, "fragment is null");
        this.tables = ImmutableMap.copyOf(Objects.requireNonNull(tables, "tables is null"));
        this.scheduledStats = Objects.requireNonNull(schedulerStats, "schedulerStats is null");
        this.stageState = new StateMachine<>("stage " + stageId, executor, StageState.PLANNED, StageState.TERMINAL_STAGE_STATES);
        this.stageState.addStateChangeListener(newState -> log.debug("Stage %s is %s", stageId, newState));
        this.finalStageInfo = new StateMachine<>("final stage " + stageId, executor, Optional.empty());
    }

    /** Returns the id of this stage. */
    public StageId getStageId() {
        return stageId;
    }

    /** Returns the session of the query this stage belongs to. */
    public Session getSession() {
        return session;
    }

    /** Returns the current stage state. */
    public StageState getState() {
        return stageState.get();
    }

    /** Returns the plan fragment executed by this stage. */
    public PlanFragment getFragment() {
        return fragment;
    }

    /** Registers a listener that is notified on every stage state change. */
    public void addStateChangeListener(StateMachine.StateChangeListener<StageState> stateChangeListener) {
        stageState.addStateChangeListener(stateChangeListener);
    }

    /** Moves PLANNED to SCHEDULING; returns whether the transition happened. */
    public synchronized boolean transitionToScheduling() {
        return stageState.compareAndSet(StageState.PLANNED, StageState.SCHEDULING);
    }

    /** Moves PLANNED or SCHEDULING to SCHEDULING_SPLITS; returns whether the transition happened. */
    public synchronized boolean transitionToSchedulingSplits() {
        return stageState.setIf(StageState.SCHEDULING_SPLITS,
                current -> current == StageState.PLANNED || current == StageState.SCHEDULING);
    }

    /**
     * Moves PLANNED, SCHEDULING or SCHEDULING_SPLITS to SCHEDULED. The scheduling-complete timestamp
     * is recorded on the first call regardless of whether the state transition succeeds.
     */
    public synchronized boolean transitionToScheduled() {
        schedulingComplete.compareAndSet(null, DateTime.now());
        return stageState.setIf(StageState.SCHEDULED,
                current -> current == StageState.PLANNED || current == StageState.SCHEDULING || current == StageState.SCHEDULING_SPLITS);
    }

    /** Moves to RUNNING unless already RUNNING, FLUSHING or done. */
    public boolean transitionToRunning() {
        return stageState.setIf(StageState.RUNNING,
                current -> !(current == StageState.RUNNING || current == StageState.FLUSHING || current.isDone()));
    }

    /** Moves to FLUSHING unless already FLUSHING or done. */
    public boolean transitionToFlushing() {
        return stageState.setIf(StageState.FLUSHING,
                current -> !(current == StageState.FLUSHING || current.isDone()));
    }

    /** Moves to FINISHED unless the stage is already done. */
    public boolean transitionToFinished() {
        return stageState.setIf(StageState.FINISHED, current -> !current.isDone());
    }

    /** Moves to CANCELED unless the stage is already done. */
    public boolean transitionToCanceled() {
        return stageState.setIf(StageState.CANCELED, current -> !current.isDone());
    }

    /** Moves to ABORTED unless the stage is already done. */
    public boolean transitionToAborted() {
        return stageState.setIf(StageState.ABORTED, current -> !current.isDone());
    }

    /**
     * Records {@code throwable} as the failure cause (first failure wins) and moves to FAILED unless
     * the stage is already done.
     *
     * @param throwable the failure; must not be null
     * @return whether the stage transitioned to FAILED
     * @throws NullPointerException if {@code throwable} is null
     */
    public boolean transitionToFailed(Throwable throwable) {
        Objects.requireNonNull(throwable, "throwable is null");
        failureCause.compareAndSet(null, Failures.toFailure(throwable));
        boolean failed = stageState.setIf(StageState.FAILED, current -> !current.isDone());
        if (failed) {
            log.error(throwable, "Stage %s failed", stageId);
        } else {
            // The stage already reached a terminal state; the late failure is only logged.
            log.debug(throwable, "Failure after stage %s finished", stageId);
        }
        return failed;
    }

    /**
     * Registers a listener that receives the final {@link StageInfo} once it is available. The
     * listener is invoked at most once.
     */
    public void addFinalStageInfoListener(StateMachine.StateChangeListener<StageInfo> finalStatusListener) {
        AtomicBoolean notified = new AtomicBoolean();
        finalStageInfo.addStateChangeListener(finalInfo -> {
            if (finalInfo.isPresent() && notified.compareAndSet(false, true)) {
                finalStatusListener.stateChanged(finalInfo.get());
            }
        });
    }

    /**
     * Freezes the stage info using the given final task infos. Only the first call stores a
     * snapshot; later calls are ignored.
     *
     * @param finalTaskInfos task infos of all tasks of this stage; must not be null
     * @throws IllegalStateException if the stage is not in a done state
     * @throws IllegalArgumentException if the resulting stage info is not complete
     */
    public void setAllTasksFinal(Iterable<TaskInfo> finalTaskInfos) {
        Objects.requireNonNull(finalTaskInfos, "finalTaskInfos is null");
        Preconditions.checkState(stageState.get().isDone(), "Stage %s is not done", stageId);
        StageInfo stageInfo = getStageInfo(() -> finalTaskInfos);
        Preconditions.checkArgument(stageInfo.isCompleteInfo(), "finalTaskInfos are not all done");
        finalStageInfo.compareAndSet(Optional.empty(), Optional.of(stageInfo));
    }

    /** Returns the current user memory reservation in bytes. */
    public long getUserMemoryReservation() {
        return currentUserMemory.get();
    }

    /** Returns the current total memory reservation in bytes. */
    public long getTotalMemoryReservation() {
        return currentTotalMemory.get();
    }

    /**
     * Applies memory deltas to the current counters and raises the user and revocable peaks if the
     * new current values exceed them.
     *
     * @param deltaUserMemoryInBytes change in user memory
     * @param deltaRevocableMemoryInBytes change in revocable memory
     * @param deltaTotalMemoryInBytes change in total memory
     */
    public void updateMemoryUsage(long deltaUserMemoryInBytes, long deltaRevocableMemoryInBytes, long deltaTotalMemoryInBytes) {
        currentUserMemory.addAndGet(deltaUserMemoryInBytes);
        currentRevocableMemory.addAndGet(deltaRevocableMemoryInBytes);
        currentTotalMemory.addAndGet(deltaTotalMemoryInBytes);
        peakUserMemory.updateAndGet(currentPeak -> Math.max(currentUserMemory.get(), currentPeak));
        peakRevocableMemory.updateAndGet(currentPeak -> Math.max(currentRevocableMemory.get(), currentPeak));
    }

    /**
     * Returns a lightweight summary of the stage statistics. Uses the final stage info when it has
     * been stored; otherwise aggregates over the task infos supplied by {@code taskInfosSupplier}.
     */
    public BasicStageStats getBasicStageStats(Supplier<Iterable<TaskInfo>> taskInfosSupplier) {
        Optional<StageInfo> finalInfo = finalStageInfo.get();
        if (finalInfo.isPresent()) {
            StageInfo stageInfo = finalInfo.get();
            return stageInfo.getStageStats().toBasicStageStats(stageInfo.getState());
        }

        StageState currentState = stageState.get();
        boolean isScheduled = currentState == StageState.RUNNING || currentState == StageState.FLUSHING || currentState.isDone();

        ImmutableList<TaskInfo> taskInfos = ImmutableList.copyOf(taskInfosSupplier.get());

        // Raw input is only summed when one of the fragment's partitioned sources is a table scan.
        // This depends on the fragment alone, so it is evaluated once instead of once per task.
        boolean readsFromTableScan = fragment.getPartitionedSourceNodes().stream()
                .anyMatch(TableScanNode.class::isInstance);

        int totalDrivers = 0;
        int queuedDrivers = 0;
        int runningDrivers = 0;
        int completedDrivers = 0;

        long cumulativeUserMemory = 0;
        long cumulativeSystemMemory = 0;
        long userMemoryReservation = 0;
        long totalMemoryReservation = 0;

        long totalScheduledTimeNanos = 0;
        long totalCpuTimeNanos = 0;

        long physicalInputDataSize = 0;
        long physicalInputPositions = 0;
        long physicalInputReadTimeNanos = 0;
        long internalNetworkInputDataSize = 0;
        long internalNetworkInputPositions = 0;
        long rawInputDataSize = 0;
        long rawInputPositions = 0;

        boolean fullyBlocked = true;

        for (TaskInfo taskInfo : taskInfos) {
            TaskStats taskStats = taskInfo.getStats();

            totalDrivers += taskStats.getTotalDrivers();
            queuedDrivers += taskStats.getQueuedDrivers();
            runningDrivers += taskStats.getRunningDrivers();
            completedDrivers += taskStats.getCompletedDrivers();

            // The compound assignment narrows each per-task value to long, as before.
            cumulativeUserMemory += taskStats.getCumulativeUserMemory();
            cumulativeSystemMemory += taskStats.getCumulativeSystemMemory();
            long userMemory = taskStats.getUserMemoryReservation().toBytes();
            userMemoryReservation += userMemory;
            totalMemoryReservation += userMemory
                    + taskStats.getSystemMemoryReservation().toBytes()
                    + taskStats.getRevocableMemoryReservation().toBytes();

            totalScheduledTimeNanos += taskStats.getTotalScheduledTime().roundTo(TimeUnit.NANOSECONDS);
            totalCpuTimeNanos += taskStats.getTotalCpuTime().roundTo(TimeUnit.NANOSECONDS);

            // Only tasks that are still active can keep the stage blocked.
            if (!taskInfo.getTaskStatus().getState().isDone()) {
                fullyBlocked &= taskStats.isFullyBlocked();
            }

            physicalInputDataSize += taskStats.getPhysicalInputDataSize().toBytes();
            physicalInputPositions += taskStats.getPhysicalInputPositions();
            physicalInputReadTimeNanos += taskStats.getPhysicalInputReadTime().roundTo(TimeUnit.NANOSECONDS);
            internalNetworkInputDataSize += taskStats.getInternalNetworkInputDataSize().toBytes();
            internalNetworkInputPositions += taskStats.getInternalNetworkInputPositions();

            if (readsFromTableScan) {
                rawInputDataSize += taskStats.getRawInputDataSize().toBytes();
                rawInputPositions += taskStats.getRawInputPositions();
            }
        }

        OptionalDouble progressPercentage = OptionalDouble.empty();
        if (isScheduled && totalDrivers != 0) {
            progressPercentage = OptionalDouble.of(Math.min(100.0d, (completedDrivers * 100.0d) / totalDrivers));
        }

        return new BasicStageStats(
                isScheduled,
                totalDrivers,
                queuedDrivers,
                runningDrivers,
                completedDrivers,
                DataSize.succinctBytes(physicalInputDataSize),
                physicalInputPositions,
                new Duration(physicalInputReadTimeNanos, TimeUnit.NANOSECONDS).convertToMostSuccinctTimeUnit(),
                DataSize.succinctBytes(internalNetworkInputDataSize),
                internalNetworkInputPositions,
                DataSize.succinctBytes(rawInputDataSize),
                rawInputPositions,
                cumulativeUserMemory,
                cumulativeSystemMemory,
                DataSize.succinctBytes(userMemoryReservation),
                DataSize.succinctBytes(totalMemoryReservation),
                new Duration(totalCpuTimeNanos, TimeUnit.NANOSECONDS).convertToMostSuccinctTimeUnit(),
                new Duration(totalScheduledTimeNanos, TimeUnit.NANOSECONDS).convertToMostSuccinctTimeUnit(),
                fullyBlocked,
                // Union of blocked reasons reported by tasks that are not yet done.
                taskInfos.stream()
                        .filter(taskInfo -> !taskInfo.getTaskStatus().getState().isDone())
                        .flatMap(taskInfo -> taskInfo.getStats().getBlockedReasons().stream())
                        .collect(Collectors.toSet()),
                progressPercentage);
    }

    /**
     * Returns the full stage info. Uses the final snapshot when it has been stored; otherwise
     * aggregates task, driver, memory, time, I/O, GC and operator statistics over the task infos
     * supplied by {@code taskInfosSupplier}.
     */
    public StageInfo getStageInfo(Supplier<Iterable<TaskInfo>> taskInfosSupplier) {
        Optional<StageInfo> finalInfo = finalStageInfo.get();
        if (finalInfo.isPresent()) {
            return finalInfo.get();
        }

        // NOTE: the stage state is captured before the task infos are fetched -- presumably so the
        // reported state is never newer than the task snapshot; keep this ordering.
        StageState currentState = stageState.get();
        ImmutableList<TaskInfo> taskInfos = ImmutableList.copyOf(taskInfosSupplier.get());

        int totalTasks = taskInfos.size();
        int runningTasks = 0;
        int completedTasks = 0;

        int totalDrivers = 0;
        int queuedDrivers = 0;
        int runningDrivers = 0;
        int blockedDrivers = 0;
        int completedDrivers = 0;

        long cumulativeUserMemory = 0;
        long cumulativeSystemMemory = 0;
        long userMemoryReservation = 0;
        long revocableMemoryReservation = 0;
        long totalMemoryReservation = 0;
        long peakUserMemoryReservation = peakUserMemory.get();
        long peakRevocableMemoryReservation = peakRevocableMemory.get();

        long totalScheduledTimeNanos = 0;
        long totalCpuTimeNanos = 0;
        long totalBlockedTimeNanos = 0;
        boolean fullyBlocked = true;

        long physicalInputDataSize = 0;
        long physicalInputPositions = 0;
        long physicalInputReadTimeNanos = 0;
        long internalNetworkInputDataSize = 0;
        long internalNetworkInputPositions = 0;
        long rawInputDataSize = 0;
        long rawInputPositions = 0;
        long processedInputDataSize = 0;
        long processedInputPositions = 0;
        long bufferedDataSize = 0;
        long outputDataSize = 0;
        long outputPositions = 0;
        long physicalWrittenDataSize = 0;

        int fullGcCount = 0;
        int fullGcTaskCount = 0;
        // Start at MAX_VALUE so the first task sets the minimum; starting at 0 pinned it to 0.
        int minFullGcSec = Integer.MAX_VALUE;
        int maxFullGcSec = 0;
        int totalFullGcSec = 0;

        // Operator summaries merged across tasks, keyed by "pipelineId.operatorId".
        Map<String, OperatorStats> operatorToStats = new HashMap<>();

        for (TaskInfo taskInfo : taskInfos) {
            boolean taskDone = taskInfo.getTaskStatus().getState().isDone();
            if (taskDone) {
                completedTasks++;
            } else {
                runningTasks++;
            }

            TaskStats taskStats = taskInfo.getStats();

            totalDrivers += taskStats.getTotalDrivers();
            queuedDrivers += taskStats.getQueuedDrivers();
            runningDrivers += taskStats.getRunningDrivers();
            blockedDrivers += taskStats.getBlockedDrivers();
            completedDrivers += taskStats.getCompletedDrivers();

            // The compound assignment narrows each per-task value to long, as before.
            cumulativeUserMemory += taskStats.getCumulativeUserMemory();
            cumulativeSystemMemory += taskStats.getCumulativeSystemMemory();
            long userMemory = taskStats.getUserMemoryReservation().toBytes();
            long systemMemory = taskStats.getSystemMemoryReservation().toBytes();
            long revocableMemory = taskStats.getRevocableMemoryReservation().toBytes();
            userMemoryReservation += userMemory;
            revocableMemoryReservation += revocableMemory;
            totalMemoryReservation += userMemory + systemMemory + revocableMemory;

            totalScheduledTimeNanos += taskStats.getTotalScheduledTime().roundTo(TimeUnit.NANOSECONDS);
            totalCpuTimeNanos += taskStats.getTotalCpuTime().roundTo(TimeUnit.NANOSECONDS);
            totalBlockedTimeNanos += taskStats.getTotalBlockedTime().roundTo(TimeUnit.NANOSECONDS);
            // Only tasks that are still active can keep the stage blocked.
            if (!taskDone) {
                fullyBlocked &= taskStats.isFullyBlocked();
            }

            physicalInputDataSize += taskStats.getPhysicalInputDataSize().toBytes();
            physicalInputPositions += taskStats.getPhysicalInputPositions();
            physicalInputReadTimeNanos += taskStats.getPhysicalInputReadTime().roundTo(TimeUnit.NANOSECONDS);
            internalNetworkInputDataSize += taskStats.getInternalNetworkInputDataSize().toBytes();
            internalNetworkInputPositions += taskStats.getInternalNetworkInputPositions();
            rawInputDataSize += taskStats.getRawInputDataSize().toBytes();
            rawInputPositions += taskStats.getRawInputPositions();
            processedInputDataSize += taskStats.getProcessedInputDataSize().toBytes();
            processedInputPositions += taskStats.getProcessedInputPositions();
            bufferedDataSize += taskInfo.getOutputBuffers().getTotalBufferedBytes();
            outputDataSize += taskStats.getOutputDataSize().toBytes();
            outputPositions += taskStats.getOutputPositions();
            physicalWrittenDataSize += taskStats.getPhysicalWrittenDataSize().toBytes();

            fullGcCount += taskStats.getFullGcCount();
            if (taskStats.getFullGcCount() > 0) {
                fullGcTaskCount++;
            }
            int fullGcSec = Math.toIntExact(taskStats.getFullGcTime().roundTo(TimeUnit.SECONDS));
            totalFullGcSec += fullGcSec;
            minFullGcSec = Math.min(minFullGcSec, fullGcSec);
            maxFullGcSec = Math.max(maxFullGcSec, fullGcSec);

            for (PipelineStats pipeline : taskStats.getPipelines()) {
                for (OperatorStats operator : pipeline.getOperatorSummaries()) {
                    operatorToStats.merge(
                            pipeline.getPipelineId() + "." + operator.getOperatorId(),
                            operator,
                            (existing, incoming) -> existing.add(incoming));
                }
            }
        }

        if (totalTasks == 0) {
            // No tasks were sampled: report 0 instead of the MAX_VALUE sentinel.
            minFullGcSec = 0;
        }
        // Guard the division: with no full GCs the average is 0 (previously it could be MAX_VALUE).
        int averageFullGcSec = fullGcCount == 0 ? 0 : (int) ((double) totalFullGcSec / fullGcCount);

        StageStats stageStats = new StageStats(
                schedulingComplete.get(),
                getSplitDistribution.snapshot(),
                totalTasks,
                runningTasks,
                completedTasks,
                totalDrivers,
                queuedDrivers,
                runningDrivers,
                blockedDrivers,
                completedDrivers,
                cumulativeUserMemory,
                cumulativeSystemMemory,
                DataSize.succinctBytes(userMemoryReservation),
                DataSize.succinctBytes(revocableMemoryReservation),
                DataSize.succinctBytes(totalMemoryReservation),
                DataSize.succinctBytes(peakUserMemoryReservation),
                DataSize.succinctBytes(peakRevocableMemoryReservation),
                Duration.succinctDuration(totalScheduledTimeNanos, TimeUnit.NANOSECONDS),
                Duration.succinctDuration(totalCpuTimeNanos, TimeUnit.NANOSECONDS),
                Duration.succinctDuration(totalBlockedTimeNanos, TimeUnit.NANOSECONDS),
                // A stage with no running tasks is never reported as fully blocked.
                fullyBlocked && runningTasks > 0,
                // Union of blocked reasons reported by tasks that are not yet done.
                taskInfos.stream()
                        .filter(taskInfo -> !taskInfo.getTaskStatus().getState().isDone())
                        .flatMap(taskInfo -> taskInfo.getStats().getBlockedReasons().stream())
                        .collect(Collectors.toSet()),
                DataSize.succinctBytes(physicalInputDataSize),
                physicalInputPositions,
                Duration.succinctDuration(physicalInputReadTimeNanos, TimeUnit.NANOSECONDS),
                DataSize.succinctBytes(internalNetworkInputDataSize),
                internalNetworkInputPositions,
                DataSize.succinctBytes(rawInputDataSize),
                rawInputPositions,
                DataSize.succinctBytes(processedInputDataSize),
                processedInputPositions,
                DataSize.succinctBytes(bufferedDataSize),
                DataSize.succinctBytes(outputDataSize),
                outputPositions,
                DataSize.succinctBytes(physicalWrittenDataSize),
                new StageGcStatistics(
                        stageId.getId(),
                        totalTasks,
                        fullGcTaskCount,
                        minFullGcSec,
                        maxFullGcSec,
                        totalFullGcSec,
                        averageFullGcSec),
                ImmutableList.copyOf(operatorToStats.values()));

        return new StageInfo(
                stageId,
                currentState,
                fragment,
                fragment.getTypes(),
                stageStats,
                taskInfos,
                ImmutableList.of(),
                tables,
                currentState == StageState.FAILED ? failureCause.get() : null);
    }

    /**
     * Records the latency of one getSplit call.
     *
     * @param startNanos value of {@link System#nanoTime()} taken when the call started
     */
    public void recordGetSplitTime(long startNanos) {
        long elapsedNanos = System.nanoTime() - startNanos;
        getSplitDistribution.add(elapsedNanos);
        scheduledStats.getGetSplitTime().add(elapsedNanos, TimeUnit.NANOSECONDS);
    }

    @Override
    public String toString() {
        return MoreObjects.toStringHelper(this)
                .add("stageId", stageId)
                .add("stageState", stageState)
                .toString();
    }
}
