/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.plugin.flink.tiered;

import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
import org.apache.celeborn.common.network.protocol.RequestMessage;
import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.protocol.PbNotifyRequiredSegment;
import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.plugin.flink.ShuffleResourceDescriptor;
import org.apache.celeborn.plugin.flink.client.CelebornBufferStream;
import org.apache.celeborn.plugin.flink.client.FlinkShuffleClientImpl;
import org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData;
import org.apache.celeborn.plugin.flink.tiered.CelebornChannelBufferManager;
import org.apache.celeborn.plugin.flink.utils.Utils;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageInputChannelId;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reads sub-partitions {@code [subPartitionIndexStart, subPartitionIndexEnd]} (both ends inclusive)
 * of one Celeborn partition through a single {@link CelebornBufferStream} and forwards every
 * received buffer to {@code dataListener}.
 *
 * <p>Incoming stream messages are dispatched by {@code messageConsumer}:
 *
 * <ul>
 *   <li>data is forwarded;
 *   <li>backlog announcements are converted into credits by requesting buffers from the
 *       {@link CelebornChannelBufferManager};
 *   <li>errors are reported to {@code failureListener} for every sub-partition;
 *   <li>stream-end asks the stream to move to the next partition if possible.
 * </ul>
 *
 * <p>NOTE(review): network callbacks and the Flink task thread presumably call into this class
 * concurrently. References that {@link #close()} may null out are therefore read into locals once
 * per method instead of being dereferenced repeatedly.
 */
public class CelebornChannelBufferReader {
    private static final Logger LOG = LoggerFactory.getLogger(CelebornChannelBufferReader.class);
    /** Created in {@link #setup}; null before setup and after {@link #close()}. */
    private CelebornChannelBufferManager bufferManager;
    private final FlinkShuffleClientImpl client;
    private final int shuffleId;
    private final int partitionId;
    private final TieredStorageInputChannelId inputChannelId;
    private final int subPartitionIndexStart;
    private final int subPartitionIndexEnd;
    private final BiConsumer<ByteBuf, TieredStorageSubpartitionId> dataListener;
    private final BiConsumer<Throwable, TieredStorageSubpartitionId> failureListener;
    private final Consumer<RequestMessage> messageConsumer;
    private CelebornBufferStream bufferStream;
    private boolean isOpened;
    /** Set by {@link #close()} or {@link #errorReceived}; stops processing further messages. */
    private volatile boolean closed = false;
    /**
     * Set only by {@link #close()}. It is separate from {@link #closed} so that an earlier
     * {@link #errorReceived} does not make close() skip releasing the stream and buffer manager.
     */
    private volatile boolean released = false;
    /** Last required segment id per sub-partition (-1 = none); nulled by {@link #close()}. */
    private volatile ConcurrentHashMap<Integer, Integer> subPartitionRequiredSegmentIds;
    /** Backlog received before {@link #setup} created the buffer manager. */
    private int numBackLog = 0;

    /**
     * Creates a reader for the sub-partition range {@code [startSubIdx, endSubIdx]} (inclusive).
     *
     * @param client shuffle client used to open the buffer stream
     * @param shuffleDescriptor provides the shuffle id and partition id to read
     * @param inputChannelId the input channel this reader serves
     * @param startSubIdx first sub-partition index (inclusive)
     * @param endSubIdx last sub-partition index (inclusive)
     * @param dataListener receives each data buffer together with its sub-partition id
     * @param failureListener receives errors, once per sub-partition in the range
     */
    public CelebornChannelBufferReader(FlinkShuffleClientImpl client, ShuffleResourceDescriptor shuffleDescriptor, TieredStorageInputChannelId inputChannelId, int startSubIdx, int endSubIdx, BiConsumer<ByteBuf, TieredStorageSubpartitionId> dataListener, BiConsumer<Throwable, TieredStorageSubpartitionId> failureListener) {
        this.client = client;
        this.shuffleId = shuffleDescriptor.getShuffleId();
        this.partitionId = shuffleDescriptor.getPartitionId();
        this.inputChannelId = inputChannelId;
        this.subPartitionIndexStart = startSubIdx;
        this.subPartitionIndexEnd = endSubIdx;
        this.dataListener = dataListener;
        this.failureListener = failureListener;
        this.subPartitionRequiredSegmentIds = JavaUtils.newConcurrentHashMap();
        for (int subPartitionId = this.subPartitionIndexStart; subPartitionId <= this.subPartitionIndexEnd; ++subPartitionId) {
            this.subPartitionRequiredSegmentIds.put(subPartitionId, -1);
        }
        this.messageConsumer = requestMessage -> {
            if (requestMessage instanceof SubPartitionReadData) {
                this.dataReceived((SubPartitionReadData) requestMessage);
            } else if (requestMessage instanceof BacklogAnnouncement) {
                this.backlogReceived(((BacklogAnnouncement) requestMessage).getBacklog());
            } else if (requestMessage instanceof TransportableError) {
                this.errorReceived(((TransportableError) requestMessage).getErrorMessage());
            } else if (requestMessage instanceof BufferStreamEnd) {
                this.onStreamEnd((BufferStreamEnd) requestMessage);
            }
        };
    }

    /**
     * Creates the buffer manager. It then turns any backlog received before this call into
     * credits for the remote stream.
     */
    public void setup(TieredStorageMemoryManager memoryManager) {
        CelebornChannelBufferManager manager = new CelebornChannelBufferManager(memoryManager, this);
        this.bufferManager = manager;
        if (this.numBackLog > 0) {
            this.requestBuffersAndAnnounceCredits(manager, this.numBackLog);
            this.numBackLog = 0;
        }
    }

    /**
     * Opens the buffer stream for this reader's sub-partition range. Failures are not thrown;
     * they are reported through the message consumer as a {@link TransportableError}.
     *
     * @param initialCredit credits granted to the remote side when the stream opens
     * @param sync passed through to {@link CelebornBufferStream}'s open call
     */
    public void open(int initialCredit, boolean sync) {
        try {
            this.bufferStream = this.client.readBufferedPartition(this.shuffleId, this.partitionId, this.subPartitionIndexStart, this.subPartitionIndexEnd, true);
            this.bufferStream.open(this::requestBuffer, initialCredit, this.messageConsumer, sync);
        } catch (Exception e) {
            this.messageConsumer.accept(new TransportableError(0L, e));
            LOG.error("Failed to open reader", e);
        }
    }

    /**
     * Closes the buffer stream and the buffer manager, then drops the segment-id bookkeeping.
     * Only the first call does anything. Resources are released even if {@link #errorReceived}
     * has already marked the reader as closed.
     */
    public void close() {
        if (this.released) {
            return;
        }
        this.released = true;
        this.closed = true;
        CelebornBufferStream stream = this.bufferStream;
        if (!CelebornBufferStream.isEmptyStream(stream)) {
            stream.close();
            this.bufferStream = null;
        } else {
            LOG.warn("bufferStream is null when closed, shuffleId: {}, partitionId: {}", this.shuffleId, this.partitionId);
        }
        try {
            CelebornChannelBufferManager manager = this.bufferManager;
            if (manager != null) {
                manager.close();
                this.bufferManager = null;
            }
        } catch (Throwable throwable) {
            LOG.warn("Failed to close buffer manager.", throwable);
        }
        ConcurrentHashMap<Integer, Integer> segmentIds = this.subPartitionRequiredSegmentIds;
        this.subPartitionRequiredSegmentIds = null;
        if (segmentIds != null) {
            segmentIds.clear();
        }
    }

    public boolean isOpened() {
        return this.isOpened;
    }

    public void setOpened(boolean opened) {
        this.isOpened = opened;
    }

    /** Returns true after {@link #close()} or after an error has been received. */
    boolean isClosed() {
        return this.closed;
    }

    /**
     * Sends {@code numCredits} credits to the remote stream. Non-positive values are ignored.
     * If the reader is closed or has no stream, the credits are dropped with a warning.
     */
    public void notifyAvailableCredits(int numCredits) {
        if (numCredits <= 0) {
            return;
        }
        CelebornBufferStream stream = this.bufferStream;
        if (!this.closed && !CelebornBufferStream.isEmptyStream(stream)) {
            stream.addCredit(PbReadAddCredit.newBuilder().setStreamId(stream.getStreamId()).setCredit(numCredits).build());
            return;
        }
        LOG.warn("The buffer stream is null or closed, ignore the credits for shuffleId: {}, partitionId: {}", this.shuffleId, this.partitionId);
    }

    /**
     * Notifies the remote side of {@code requiredSegmentId} for {@code subPartitionId}, but only
     * when the id is non-negative and differs from the last one recorded.
     *
     * <p>If the notification cannot be sent, the previously recorded id is restored. A later call
     * with the same id will then retry instead of being skipped as a duplicate. Does nothing
     * after {@link #close()}.
     */
    public void notifyRequiredSegmentIfNeeded(int requiredSegmentId, int subPartitionId) {
        ConcurrentHashMap<Integer, Integer> segmentIds = this.subPartitionRequiredSegmentIds;
        if (segmentIds == null) {
            // Reader already closed; there is no stream left to notify.
            return;
        }
        Integer lastRequiredSegmentId = segmentIds.computeIfAbsent(subPartitionId, id -> -1);
        if (requiredSegmentId < 0 || requiredSegmentId == lastRequiredSegmentId) {
            return;
        }
        LOG.debug("Notify required segment id {} for {} {}, the last segment id is {}", requiredSegmentId, this.partitionId, subPartitionId, lastRequiredSegmentId);
        segmentIds.put(subPartitionId, requiredSegmentId);
        if (!this.notifyRequiredSegment(requiredSegmentId, subPartitionId)) {
            // Roll back. The original putIfAbsent was a no-op because the key had just been put.
            segmentIds.put(subPartitionId, lastRequiredSegmentId);
        }
    }

    /**
     * Records {@code requiredSegmentId} for {@code subPartitionId} and sends a
     * {@link PbNotifyRequiredSegment} on the current stream.
     *
     * @return true if the notification was sent; false if the reader is closed or has no stream
     */
    public boolean notifyRequiredSegment(int requiredSegmentId, int subPartitionId) {
        ConcurrentHashMap<Integer, Integer> segmentIds = this.subPartitionRequiredSegmentIds;
        if (segmentIds != null) {
            segmentIds.put(subPartitionId, requiredSegmentId);
        }
        CelebornBufferStream stream = this.bufferStream;
        if (this.closed || CelebornBufferStream.isEmptyStream(stream)) {
            return false;
        }
        LOG.debug("Notify required segmentId {} for {} {} {}", requiredSegmentId, this.partitionId, subPartitionId, this.shuffleId);
        PbNotifyRequiredSegment notifyRequiredSegment = PbNotifyRequiredSegment.newBuilder().setStreamId(stream.getStreamId()).setRequiredSegmentId(requiredSegmentId).setSubPartitionId(subPartitionId).build();
        stream.notifyRequiredSegment(notifyRequiredSegment);
        return true;
    }

    /**
     * Supplies a buffer to the stream (used as the buffer supplier in {@link #open}).
     *
     * @return a buffer as a {@link ByteBuf}, or null if none is available. Null is also returned
     *     when the buffer manager does not exist (before {@link #setup} or after {@link #close()}).
     */
    public ByteBuf requestBuffer() {
        CelebornChannelBufferManager manager = this.bufferManager;
        if (manager == null) {
            return null;
        }
        Buffer buffer = manager.requestBuffer();
        return buffer == null ? null : buffer.asByteBuf();
    }

    /**
     * Handles a backlog announcement.
     *
     * <p>Before {@link #setup}, the backlog is accumulated for setup to process. Otherwise
     * buffers are requested and the obtained number is announced as credits. The backlog is
     * ignored, with a warning, once the reader is closed.
     */
    public void backlogReceived(int backlog) {
        if (this.closed) {
            CelebornBufferStream stream = this.bufferStream;
            // close() nulls bufferStream, so it must not be dereferenced unconditionally here.
            Long streamId = stream == null ? null : stream.getStreamId();
            LOG.warn("The buffer stream {} is null or closed, ignore the backlog for shuffleId: {}, partitionId: {}", streamId, this.shuffleId, this.partitionId);
            return;
        }
        CelebornChannelBufferManager manager = this.bufferManager;
        if (manager == null) {
            this.numBackLog += backlog;
            return;
        }
        this.requestBuffersAndAnnounceCredits(manager, backlog);
        this.numBackLog = 0;
    }

    /**
     * Marks the reader closed and reports {@code errorMsg} as an {@link IOException} to
     * {@code failureListener} for every sub-partition in range. Only the first error is reported.
     * Resources are still released by a later {@link #close()}.
     */
    public void errorReceived(String errorMsg) {
        if (this.closed) {
            return;
        }
        this.closed = true;
        LOG.debug("Error received, {}", errorMsg);
        CelebornBufferStream stream = this.bufferStream;
        if (!CelebornBufferStream.isEmptyStream(stream) && stream.getClient() != null) {
            LOG.error("Received error from {} message {}", NettyUtils.getRemoteAddress(stream.getClient().getChannel()), errorMsg);
        }
        for (int subPartitionId = this.subPartitionIndexStart; subPartitionId <= this.subPartitionIndexEnd; ++subPartitionId) {
            this.failureListener.accept(new IOException(errorMsg), new TieredStorageSubpartitionId(subPartitionId));
        }
    }

    /**
     * Forwards a data buffer to {@code dataListener}, then tops up credits if the buffer manager
     * decides more buffers are needed.
     *
     * @throws RuntimeException from {@code Utils.checkState} if the sub-partition id is out of
     *     range (exact type depends on Utils — TODO confirm)
     */
    public void dataReceived(SubPartitionReadData readData) {
        int subPartitionId = readData.getSubPartitionId();
        LOG.debug("Remote buffer stream reader get stream id {} subPartitionId {} received readable bytes {}.", readData.getStreamId(), subPartitionId, readData.getFlinkBuffer().readableBytes());
        Utils.checkState(subPartitionId >= this.subPartitionIndexStart && subPartitionId <= this.subPartitionIndexEnd, "Wrong sub partition id: " + subPartitionId);
        this.dataListener.accept(readData.getFlinkBuffer(), new TieredStorageSubpartitionId(subPartitionId));
        CelebornChannelBufferManager manager = this.bufferManager;
        if (manager == null) {
            // Not set up yet, or closed concurrently: there is nothing to request credits from.
            return;
        }
        this.announceCredits(manager, manager.tryRequestBuffersIfNeeded());
    }

    /**
     * On stream end, asks the stream to move to the next partition if possible. After the move,
     * the recorded required segment id is re-sent.
     */
    public void onStreamEnd(BufferStreamEnd streamEnd) {
        long streamId = streamEnd.getStreamId();
        LOG.debug("Buffer stream reader get stream end for {}", streamId);
        CelebornBufferStream stream = this.bufferStream;
        if (!this.closed && !CelebornBufferStream.isEmptyStream(stream)) {
            stream.moveToNextPartitionIfPossible(streamId, this::sendRequireSegmentId, true);
        }
    }

    public TieredStorageInputChannelId getInputChannelId() {
        return this.inputChannelId;
    }

    /**
     * Re-sends the recorded required segment id for {@code subPartitionId}. This happens only if
     * one has been recorded (id >= 0) and the stream is open.
     */
    private void sendRequireSegmentId(long streamId, int subPartitionId) {
        ConcurrentHashMap<Integer, Integer> segmentIds = this.subPartitionRequiredSegmentIds;
        CelebornBufferStream stream = this.bufferStream;
        if (segmentIds == null || stream == null) {
            // The reader was closed while the stream was switching partitions.
            return;
        }
        // ConcurrentHashMap forbids null values, so null here means "no entry".
        Integer currentSegmentId = segmentIds.get(subPartitionId);
        if (currentSegmentId != null && stream.isOpened() && currentSegmentId >= 0) {
            LOG.debug("Buffer stream {} is opened, notify required segment id {} ", streamId, currentSegmentId);
            this.notifyRequiredSegment(currentSegmentId, subPartitionId);
        }
    }

    /**
     * Requests up to {@code numBuffers} buffers from {@code manager}. The returned count is then
     * announced as credits via {@link #announceCredits}.
     */
    private void requestBuffersAndAnnounceCredits(CelebornChannelBufferManager manager, int numBuffers) {
        this.announceCredits(manager, manager.requestBuffers(numBuffers));
    }

    /**
     * For a positive {@code numCredits}, decreases the manager's required credits and sends the
     * same number of credits to the remote stream.
     */
    private void announceCredits(CelebornChannelBufferManager manager, int numCredits) {
        if (numCredits > 0) {
            manager.decreaseRequiredCredits(numCredits);
            this.notifyAvailableCredits(numCredits);
        }
    }
}

