/*
 * Decompiled with CFR 0.152.
 */
package com.linecorp.armeria.client.websocket;

import com.linecorp.armeria.client.ClientOptions;
import com.linecorp.armeria.client.ClientRequestContext;
import com.linecorp.armeria.client.ClientRequestContextCaptor;
import com.linecorp.armeria.client.Clients;
import com.linecorp.armeria.client.RequestOptions;
import com.linecorp.armeria.client.WebClient;
import com.linecorp.armeria.client.endpoint.EndpointGroup;
import com.linecorp.armeria.client.websocket.WebSocketClient;
import com.linecorp.armeria.client.websocket.WebSocketClientFrameDecoder;
import com.linecorp.armeria.client.websocket.WebSocketClientHandshakeException;
import com.linecorp.armeria.client.websocket.WebSocketSession;
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpHeaders;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.RequestHeadersBuilder;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.Scheme;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.logging.RequestLogProperty;
import com.linecorp.armeria.common.stream.ByteStreamMessage;
import com.linecorp.armeria.common.stream.StreamMessage;
import com.linecorp.armeria.internal.client.ClientUtil;
import com.linecorp.armeria.internal.common.DefaultSplitHttpResponse;
import com.linecorp.armeria.internal.common.websocket.WebSocketFrameEncoder;
import com.linecorp.armeria.internal.common.websocket.WebSocketUtil;
import com.linecorp.armeria.internal.common.websocket.WebSocketWrapper;
import com.linecorp.armeria.internal.shaded.guava.base.Joiner;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.util.concurrent.EventExecutor;
import java.net.URI;
import java.util.Base64;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;

final class DefaultWebSocketClient
implements WebSocketClient {
    static final WebSocketClient DEFAULT = WebSocketClient.of(ClientUtil.UNDEFINED_URI);
    private static final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(true);
    private final WebClient webClient;
    private final int maxFramePayloadLength;
    private final boolean allowMaskMismatch;
    private final List<String> subprotocols;
    private final String joinedSubprotocols;
    private final boolean aggregateContinuation;

    DefaultWebSocketClient(WebClient webClient, int maxFramePayloadLength, boolean allowMaskMismatch, List<String> subprotocols, boolean aggregateContinuation) {
        this.webClient = webClient;
        this.maxFramePayloadLength = maxFramePayloadLength;
        this.allowMaskMismatch = allowMaskMismatch;
        this.subprotocols = subprotocols;
        this.joinedSubprotocols = !subprotocols.isEmpty() ? Joiner.on(", ").join(subprotocols) : "";
        this.aggregateContinuation = aggregateContinuation;
    }

    @Override
    public CompletableFuture<WebSocketSession> connect(String path, HttpHeaders headers, RequestOptions requestOptions) {
        ClientRequestContext ctx;
        HttpResponse response;
        Objects.requireNonNull(path, "path");
        RequestHeaders requestHeaders = this.webSocketHeaders(path, headers);
        CompletableFuture outboundFuture = new CompletableFuture();
        HttpRequest request = HttpRequest.of(requestHeaders, StreamMessage.of(outboundFuture));
        try (ClientRequestContextCaptor captor = Clients.newContextCaptor();){
            response = this.webClient.execute(request, requestOptions);
            ctx = captor.get();
        }
        DefaultSplitHttpResponse split = new DefaultSplitHttpResponse(response, (EventExecutor)ctx.eventLoop(), responseHeaders -> {
            SessionProtocol actualSessionProtocol = DefaultWebSocketClient.actualSessionProtocol(ctx);
            if (actualSessionProtocol.isExplicitHttp1()) {
                return true;
            }
            assert (actualSessionProtocol.isExplicitHttp2());
            return !responseHeaders.status().isInformational();
        });
        CompletableFuture<WebSocketSession> result = new CompletableFuture<WebSocketSession>();
        split.headers().handle((responseHeaders, cause) -> {
            if (cause != null) {
                DefaultWebSocketClient.fail(outboundFuture, split.body(), result, cause);
                return null;
            }
            if (!this.validateResponseHeaders(ctx, requestHeaders, (ResponseHeaders)responseHeaders, outboundFuture, split.body(), result)) {
                return null;
            }
            WebSocketClientFrameDecoder decoder = new WebSocketClientFrameDecoder(ctx, this.maxFramePayloadLength, this.allowMaskMismatch, this.aggregateContinuation);
            WebSocketWrapper inbound = new WebSocketWrapper(split.body().decode(decoder, ctx.alloc()));
            result.complete(new WebSocketSession(ctx, (ResponseHeaders)responseHeaders, inbound, outboundFuture, encoder));
            return null;
        });
        return result;
    }

    private RequestHeaders webSocketHeaders(String path, HttpHeaders headers) {
        RequestHeadersBuilder builder = RequestHeaders.builder();
        if (!headers.isEmpty()) {
            headers.forEach((k, v) -> builder.add((CharSequence)k, (String)v));
        }
        if (this.scheme().sessionProtocol().isExplicitHttp2()) {
            builder.method(HttpMethod.CONNECT).path(path).set((CharSequence)HttpHeaderNames.PROTOCOL, HttpHeaderValues.WEBSOCKET.toString());
        } else {
            String secWebSocketKey = DefaultWebSocketClient.generateSecWebSocketKey();
            builder.method(HttpMethod.GET).path(path).set((CharSequence)HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE.toString()).set((CharSequence)HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString()).set((CharSequence)HttpHeaderNames.SEC_WEBSOCKET_KEY, secWebSocketKey);
        }
        builder.set((CharSequence)HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13");
        if (!builder.contains(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL) && !this.subprotocols.isEmpty()) {
            builder.set((CharSequence)HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, this.joinedSubprotocols);
        }
        return builder.build();
    }

    private boolean validateResponseHeaders(ClientRequestContext ctx, RequestHeaders requestHeaders, ResponseHeaders responseHeaders, CompletableFuture<StreamMessage<HttpData>> outboundFuture, ByteStreamMessage responseBody, CompletableFuture<WebSocketSession> result) {
        String responseSubprotocol;
        if (DefaultWebSocketClient.actualSessionProtocol(ctx).isExplicitHttp2()) {
            HttpStatus status = responseHeaders.status();
            if (status != HttpStatus.OK) {
                DefaultWebSocketClient.fail(outboundFuture, responseBody, result, new WebSocketClientHandshakeException("invalid status: " + status + " (expected: " + HttpStatus.OK + ')', responseHeaders));
                return false;
            }
        } else {
            if (!DefaultWebSocketClient.isHttp1WebSocketResponse(responseHeaders)) {
                DefaultWebSocketClient.fail(outboundFuture, responseBody, result, new WebSocketClientHandshakeException("invalid response headers: " + responseHeaders, responseHeaders));
                return false;
            }
            String secWebSocketKey = requestHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
            assert (secWebSocketKey != null);
            String secWebSocketAccept = responseHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
            if (secWebSocketAccept == null) {
                DefaultWebSocketClient.fail(outboundFuture, responseBody, result, new WebSocketClientHandshakeException(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT + " is null.", responseHeaders));
                return false;
            }
            if (!secWebSocketAccept.equals(WebSocketUtil.generateSecWebSocketAccept(secWebSocketKey))) {
                DefaultWebSocketClient.fail(outboundFuture, responseBody, result, new WebSocketClientHandshakeException("invalid " + HttpHeaderNames.SEC_WEBSOCKET_ACCEPT + " header: " + secWebSocketAccept, responseHeaders));
                return false;
            }
        }
        if (!this.subprotocols.isEmpty() && (responseSubprotocol = responseHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)) != null && !this.subprotocols.contains(responseSubprotocol)) {
            DefaultWebSocketClient.fail(outboundFuture, responseBody, result, new WebSocketClientHandshakeException("invalid " + HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL + " header: " + responseSubprotocol + " (expected: one of " + this.subprotocols + ')', responseHeaders));
            return false;
        }
        return true;
    }

    private static SessionProtocol actualSessionProtocol(ClientRequestContext ctx) {
        return ctx.log().ensureAvailable(RequestLogProperty.SESSION).sessionProtocol();
    }

    private static void fail(CompletableFuture<StreamMessage<HttpData>> outboundFuture, ByteStreamMessage responseBody, CompletableFuture<WebSocketSession> result, Throwable cause) {
        outboundFuture.completeExceptionally(cause);
        responseBody.abort(cause);
        result.completeExceptionally(cause);
    }

    static String generateSecWebSocketKey() {
        byte[] bytes = new byte[16];
        ThreadLocalRandom.current().nextBytes(bytes);
        return Base64.getEncoder().encodeToString(bytes);
    }

    private static boolean isHttp1WebSocketResponse(ResponseHeaders responseHeaders) {
        return responseHeaders.status() == HttpStatus.SWITCHING_PROTOCOLS && HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(responseHeaders.get(HttpHeaderNames.UPGRADE)) && HttpHeaderValues.UPGRADE.contentEqualsIgnoreCase(responseHeaders.get(HttpHeaderNames.CONNECTION));
    }

    @Override
    public Scheme scheme() {
        return this.webClient.scheme();
    }

    @Override
    public EndpointGroup endpointGroup() {
        return this.webClient.endpointGroup();
    }

    @Override
    public String absolutePathRef() {
        return this.webClient.absolutePathRef();
    }

    @Override
    public URI uri() {
        return this.webClient.uri();
    }

    @Override
    public Class<?> clientType() {
        return this.webClient.clientType();
    }

    @Override
    public ClientOptions options() {
        return this.webClient.options();
    }

    @Override
    public WebClient unwrap() {
        return this.webClient;
    }
}

