package io.trino.execution;

import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.airlift.concurrent.Threads;
import io.airlift.stats.CounterStat;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.units.DataSize;
import io.trino.exchange.ExchangeManagerRegistry;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.executor.TaskExecutor;
import io.trino.memory.MemoryPool;
import io.trino.memory.QueryContext;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.metadata.ExchangeHandleResolver;
import io.trino.operator.DriverContext;
import io.trino.operator.OperatorContext;
import io.trino.operator.PipelineContext;
import io.trino.operator.TaskContext;
import io.trino.spi.QueryId;
import io.trino.spiller.SpillSpaceTracker;
import io.trino.sql.planner.plan.PlanNodeId;
import java.net.URI;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import org.testng.Assert;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

/**
 * Tests {@link MemoryRevokingScheduler} against a deliberately tiny (10-byte) {@link MemoryPool},
 * checking which operators get asked to revoke memory as revocable reservations change.
 *
 * <p>All asynchronous work runs on {@link #executor}, which is single-threaded so that
 * {@link #awaitAsynchronousCallbacksRun()} can use it as a barrier.
 */
@Test(singleThreaded = true)
public class TestMemoryRevokingScheduler {
    private final AtomicInteger idGenerator = new AtomicInteger();
    private final SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(DataSize.of(10, DataSize.Unit.GIGABYTE));
    private final Map<QueryId, QueryContext> queryContexts = new HashMap<>();
    private ScheduledExecutorService executor;
    private ScheduledExecutorService scheduledExecutor;
    private TaskExecutor taskExecutor;
    private SqlTaskExecutionFactory sqlTaskExecutionFactory;
    private MemoryPool memoryPool;
    // Every operator context created by the current test; used to assert that
    // operators NOT named in an expectation have no pending revoke request.
    private Set<OperatorContext> allOperatorContexts;

    /**
     * Creates a fresh 10-byte memory pool, a started task executor and the executors used by
     * the SQL tasks under test.
     */
    @BeforeMethod
    public void setUp() {
        memoryPool = new MemoryPool(DataSize.ofBytes(10L));

        // Kept in a field so tearDown can stop it; otherwise its threads leak across test methods.
        taskExecutor = new TaskExecutor(8, 16, 3, 4, Ticker.systemTicker());
        taskExecutor.start();

        // Must stay single-threaded: awaitAsynchronousCallbacksRun relies on FIFO execution.
        executor = Executors.newScheduledThreadPool(1, Threads.threadsNamed("task-notification-%s"));
        scheduledExecutor = Executors.newScheduledThreadPool(2, Threads.threadsNamed("task-notification-%s"));

        sqlTaskExecutionFactory = new SqlTaskExecutionFactory(
                executor,
                taskExecutor,
                TaskTestUtils.createTestingPlanner(),
                TaskTestUtils.createTestSplitMonitor(),
                new TaskManagerConfig());
        allOperatorContexts = null;
    }

    /**
     * Releases per-test state and shuts down every executor started by {@link #setUp()}.
     * Runs even when setUp failed part-way, so each resource is null-checked.
     */
    @AfterMethod(alwaysRun = true)
    public void tearDown() {
        queryContexts.clear();
        memoryPool = null;
        if (executor != null) {
            executor.shutdownNow();
            executor = null;
        }
        if (scheduledExecutor != null) {
            scheduledExecutor.shutdownNow();
            scheduledExecutor = null;
        }
        if (taskExecutor != null) {
            taskExecutor.stop();
            taskExecutor = null;
        }
    }

    /**
     * Drives {@link MemoryRevokingScheduler#requestMemoryRevokingIfNeeded()} explicitly while
     * growing and shrinking revocable reservations of five operators spread across two queries.
     */
    @Test
    public void testScheduleMemoryRevoking() throws Exception {
        QueryContext queryContext1 = getOrCreateQueryContext(new QueryId("q1"));
        QueryContext queryContext2 = getOrCreateQueryContext(new QueryId("q2"));

        SqlTask sqlTask1 = newSqlTask(queryContext1.getQueryId());
        SqlTask sqlTask2 = newSqlTask(queryContext2.getQueryId());

        // Task 1: one pipeline with two drivers (operators 1 and 2 on the first, 3 on the second).
        PipelineContext pipeline1 = getOrCreateTaskContext(sqlTask1).addPipelineContext(0, false, false, false);
        DriverContext driver1 = pipeline1.addDriverContext();
        OperatorContext operator1 = driver1.addOperatorContext(1, new PlanNodeId("na"), "na");
        OperatorContext operator2 = driver1.addOperatorContext(2, new PlanNodeId("na"), "na");
        OperatorContext operator3 = pipeline1.addDriverContext().addOperatorContext(3, new PlanNodeId("na"), "na");

        // Task 2: one pipeline with a single driver holding operators 4 and 5.
        DriverContext driver2 = getOrCreateTaskContext(sqlTask2).addPipelineContext(1, false, false, false).addDriverContext();
        OperatorContext operator4 = driver2.addOperatorContext(4, new PlanNodeId("na"), "na");
        OperatorContext operator5 = driver2.addOperatorContext(5, new PlanNodeId("na"), "na");

        ImmutableList<SqlTask> tasks = ImmutableList.of(sqlTask1, sqlTask2);
        MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(memoryPool, () -> tasks, executor, 1.0, 1.0);
        allOperatorContexts = ImmutableSet.of(operator1, operator2, operator3, operator4, operator5);

        // Nothing reserved: no revoking.
        assertMemoryRevokingNotRequested();
        requestMemoryRevoking(scheduler);
        Assert.assertEquals(10L, memoryPool.getFreeBytes());
        assertMemoryRevokingNotRequested();

        LocalMemoryContext revocable1 = operator1.localRevocableMemoryContext();
        LocalMemoryContext revocable3 = operator3.localRevocableMemoryContext();
        LocalMemoryContext revocable4 = operator4.localRevocableMemoryContext();
        LocalMemoryContext revocable5 = operator5.localRevocableMemoryContext();

        // 3 + 6 = 9 of 10 bytes: pool not exhausted, so no revoking.
        revocable1.setBytes(3L);
        revocable3.setBytes(6L);
        Assert.assertEquals(1L, memoryPool.getFreeBytes());
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingNotRequested();

        // 3 + 6 + 7 = 16 bytes: pool over-committed by 6; operators 1 and 3 are asked to revoke.
        revocable4.setBytes(7L);
        Assert.assertEquals(-6L, memoryPool.getFreeBytes());
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingRequestedFor(operator1, operator3);
        // Asking again does not change the set of requested operators.
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingRequestedFor(operator1, operator3);

        // Operator 1 "completes" its revoke: releases memory and clears the request.
        revocable1.setBytes(0L);
        operator1.resetMemoryRevokingRequested();
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingRequestedFor(operator3);
        Assert.assertEquals(-3L, memoryPool.getFreeBytes());

        // 6 bytes over; NOTE(review): presumably operator 3's pending 6-byte revoke already
        // covers the deficit, so nothing new is requested — confirm against scheduler policy.
        revocable5.setBytes(3L);
        Assert.assertEquals(-6L, memoryPool.getFreeBytes());
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingRequestedFor(operator3);

        // 7 bytes over: operator 4 is now also asked to revoke.
        revocable5.setBytes(4L);
        Assert.assertEquals(-7L, memoryPool.getFreeBytes());
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingRequestedFor(operator3, operator4);
    }

    /**
     * Verifies that, once pool listeners are registered, exceeding the pool (12 bytes in a 10-byte
     * pool) triggers a revoke request without an explicit call to requestMemoryRevokingIfNeeded.
     */
    @Test
    public void testImmediateMemoryRevoking() throws Exception {
        SqlTask sqlTask = newSqlTask(new QueryId("query"));
        OperatorContext operatorContext = createContexts(sqlTask);
        allOperatorContexts = ImmutableSet.of(operatorContext);

        ImmutableList<SqlTask> tasks = ImmutableList.of(sqlTask);
        new MemoryRevokingScheduler(memoryPool, () -> tasks, executor, 1.0, 1.0).registerPoolListeners();

        operatorContext.localRevocableMemoryContext().setBytes(12L);
        awaitAsynchronousCallbacksRun();
        assertMemoryRevokingRequestedFor(operatorContext);
    }

    /** Creates a single pipeline/driver in the task and returns its one operator context (id 1). */
    private OperatorContext createContexts(SqlTask sqlTask) {
        return getOrCreateTaskContext(sqlTask)
                .addPipelineContext(0, false, false, false)
                .addDriverContext()
                .addOperatorContext(1, new PlanNodeId("na"), "na");
    }

    /** Asks the scheduler to revoke if needed, then waits for any work it queued on {@link #executor}. */
    private void requestMemoryRevoking(MemoryRevokingScheduler scheduler) throws Exception {
        scheduler.requestMemoryRevokingIfNeeded();
        awaitAsynchronousCallbacksRun();
    }

    /**
     * Barrier: submits a no-op to the single-threaded {@link #executor} and waits for it, so every
     * task queued before it has already run.
     */
    private void awaitAsynchronousCallbacksRun() throws Exception {
        executor.invokeAll(Collections.singletonList(() -> null));
    }

    /**
     * Asserts that exactly the given operators (out of {@link #allOperatorContexts}) have a pending
     * memory-revoking request.
     */
    private void assertMemoryRevokingRequestedFor(OperatorContext... operatorContexts) {
        ImmutableSet<OperatorContext> expected = ImmutableSet.copyOf(operatorContexts);
        for (OperatorContext requested : expected) {
            Assert.assertTrue(requested.isMemoryRevokingRequested(), "expected memory requested for operator " + requested.getOperatorId());
        }
        for (OperatorContext notRequested : Sets.difference(allOperatorContexts, expected)) {
            Assert.assertFalse(notRequested.isMemoryRevokingRequested(), "expected memory  not requested for operator " + notRequested.getOperatorId());
        }
    }

    /** Asserts that no operator in {@link #allOperatorContexts} has a pending revoking request. */
    private void assertMemoryRevokingNotRequested() {
        assertMemoryRevokingRequestedFor();
    }

    /** Creates a new SQL task for the query, with a unique task id from {@link #idGenerator}. */
    private SqlTask newSqlTask(QueryId queryId) {
        QueryContext queryContext = getOrCreateQueryContext(queryId);
        TaskId taskId = new TaskId(new StageId(queryId.getId(), 0), idGenerator.incrementAndGet(), 0);
        return SqlTask.createSqlTask(
                taskId,
                URI.create("fake://task/" + taskId),
                "fake",
                queryContext,
                sqlTaskExecutionFactory,
                executor,
                ignored -> {},
                DataSize.of(32L, DataSize.Unit.MEGABYTE),
                DataSize.of(200L, DataSize.Unit.MEGABYTE),
                new ExchangeManagerRegistry(new ExchangeHandleResolver()),
                new CounterStat());
    }

    /** Returns the cached query context for the id, creating one backed by {@link #memoryPool} on first use. */
    private QueryContext getOrCreateQueryContext(QueryId queryId) {
        return queryContexts.computeIfAbsent(queryId, id -> new QueryContext(
                id,
                DataSize.of(1L, DataSize.Unit.MEGABYTE),
                memoryPool,
                new TestingGcMonitor(),
                executor,
                scheduledExecutor,
                DataSize.of(1L, DataSize.Unit.GIGABYTE),
                spillSpaceTracker));
    }

    /**
     * Returns the task's context, first sending it an empty update (no splits, partitioned output
     * with buffer {@code TestSqlTask.OUT}) if no context exists yet.
     *
     * @throws IllegalStateException if the task still has no context after the update
     */
    private TaskContext getOrCreateTaskContext(SqlTask sqlTask) {
        if (sqlTask.getTaskContext().isEmpty()) {
            TaskTestUtils.updateTask(
                    sqlTask,
                    ImmutableList.of(),
                    OutputBuffers.createInitialEmptyOutputBuffers(OutputBuffers.BufferType.PARTITIONED)
                            .withBuffer(TestSqlTask.OUT, 0)
                            .withNoMoreBufferIds());
        }
        return sqlTask.getTaskContext()
                .orElseThrow(() -> new IllegalStateException("TaskContext not present"));
    }
}
