/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.common.input.parameter.regression;

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
import lombok.Generated;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.annotation.MLAlgoParameter;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;

/**
 * Hyper-parameters for the {@link FunctionName#LOGISTIC_REGRESSION} algorithm.
 *
 * <p>Every field is optional. A {@code null} field is omitted from the XContent produced by
 * {@link #toXContent} and is encoded as "absent" by {@link #writeTo}. Instances can be created
 * through the all-args constructor, {@link #builder()}, the transport-stream constructor
 * {@link #LogisticRegressionParams(StreamInput)}, or {@link #parse(XContentParser)}.
 */
@MLAlgoParameter(algorithms = {FunctionName.LOGISTIC_REGRESSION})
public class LogisticRegressionParams implements MLAlgoParams {
    /** Name under which this params type is registered; also the writeable name. */
    public static final String PARSE_FIELD_NAME = FunctionName.LOGISTIC_REGRESSION.name();
    /** Registers {@link #parse(XContentParser)} as the XContent parser for this params type. */
    public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
        MLAlgoParams.class,
        new ParseField(PARSE_FIELD_NAME),
        it -> parse(it)
    );

    // XContent field names (used by both parse and toXContent).
    public static final String OBJECTIVE_FIELD = "objective";
    public static final String OPTIMISER_FIELD = "optimiser";
    public static final String MOMENTUM_TYPE_FIELD = "momentum_type";
    public static final String LEARNING_RATE_FIELD = "learning_rate";
    public static final String EPSILON_FIELD = "epsilon";
    public static final String MOMENTUM_FACTOR_FIELD = "momentum_factor";
    public static final String BETA1_FIELD = "beta1";
    public static final String BETA2_FIELD = "beta2";
    public static final String DECAY_RATE_FIELD = "decay_rate";
    public static final String EPOCHS_FIELD = "epochs";
    public static final String BATCH_SIZE_FIELD = "batch_size";
    public static final String LOGGING_INTERVAL_FIELD = "logging_interval";
    public static final String SEED_FIELD = "seed";
    public static final String TARGET_FIELD = "target";

    private ObjectiveType objectiveType;
    private OptimizerType optimizerType;
    private MomentumType momentumType;
    private Double learningRate;
    private Double epsilon;
    private Double momentumFactor;
    private Double beta1;
    private Double beta2;
    private Double decayRate;
    private Integer epochs;
    private Integer batchSize;
    private Integer loggingInterval;
    private Long seed;
    private String target;

    /**
     * Creates a params object from explicit values; any argument may be {@code null}.
     */
    public LogisticRegressionParams(
        ObjectiveType objectiveType,
        OptimizerType optimizerType,
        MomentumType momentumType,
        Double learningRate,
        Double epsilon,
        Double momentumFactor,
        Double beta1,
        Double beta2,
        Double decayRate,
        Integer epochs,
        Integer batchSize,
        Integer loggingInterval,
        Long seed,
        String target
    ) {
        this.objectiveType = objectiveType;
        this.optimizerType = optimizerType;
        this.momentumType = momentumType;
        this.learningRate = learningRate;
        this.epsilon = epsilon;
        this.momentumFactor = momentumFactor;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.decayRate = decayRate;
        this.epochs = epochs;
        this.batchSize = batchSize;
        this.loggingInterval = loggingInterval;
        this.seed = seed;
        this.target = target;
    }

    /**
     * Deserializes params previously written by {@link #writeTo(StreamOutput)}.
     * The read order must match the write order exactly.
     *
     * @param in transport stream to read from
     * @throws IOException if the stream cannot be read
     */
    public LogisticRegressionParams(StreamInput in) throws IOException {
        this.objectiveType = readOptionalEnumValue(in, ObjectiveType.class);
        this.optimizerType = readOptionalEnumValue(in, OptimizerType.class);
        this.momentumType = readOptionalEnumValue(in, MomentumType.class);
        this.learningRate = in.readOptionalDouble();
        this.epsilon = in.readOptionalDouble();
        this.momentumFactor = in.readOptionalDouble();
        this.beta1 = in.readOptionalDouble();
        this.beta2 = in.readOptionalDouble();
        this.decayRate = in.readOptionalDouble();
        this.epochs = in.readOptionalInt();
        this.batchSize = in.readOptionalInt();
        this.loggingInterval = in.readOptionalInt();
        this.seed = in.readOptionalLong();
        this.target = in.readOptionalString();
    }

    /**
     * Parses params from an XContent object whose START_OBJECT token is the parser's current token.
     *
     * <p>Enum-valued fields are matched case-insensitively (upper-cased with {@link Locale#ROOT}).
     * Unknown fields are skipped. A repeated field keeps its last value.
     *
     * @param parser parser positioned on START_OBJECT
     * @return the parsed params
     * @throws IOException on read errors
     * @throws IllegalArgumentException if an enum-valued field has an unrecognized value
     */
    public static MLAlgoParams parse(XContentParser parser) throws IOException {
        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        LogisticRegressionParamsBuilder params = builder();
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();
            switch (fieldName) {
                case OBJECTIVE_FIELD:
                    params.objectiveType(ObjectiveType.from(parser.text().toUpperCase(Locale.ROOT)));
                    break;
                case OPTIMISER_FIELD:
                    params.optimizerType(OptimizerType.from(parser.text().toUpperCase(Locale.ROOT)));
                    break;
                case MOMENTUM_TYPE_FIELD:
                    params.momentumType(MomentumType.from(parser.text().toUpperCase(Locale.ROOT)));
                    break;
                case LEARNING_RATE_FIELD:
                    params.learningRate(parser.doubleValue(false));
                    break;
                case EPSILON_FIELD:
                    params.epsilon(parser.doubleValue(false));
                    break;
                case MOMENTUM_FACTOR_FIELD:
                    params.momentumFactor(parser.doubleValue(false));
                    break;
                case BETA1_FIELD:
                    params.beta1(parser.doubleValue(false));
                    break;
                case BETA2_FIELD:
                    params.beta2(parser.doubleValue(false));
                    break;
                case DECAY_RATE_FIELD:
                    params.decayRate(parser.doubleValue(false));
                    break;
                case EPOCHS_FIELD:
                    params.epochs(parser.intValue(false));
                    break;
                case BATCH_SIZE_FIELD:
                    params.batchSize(parser.intValue(false));
                    break;
                case LOGGING_INTERVAL_FIELD:
                    params.loggingInterval(parser.intValue(false));
                    break;
                case SEED_FIELD:
                    params.seed(parser.longValue(false));
                    break;
                case TARGET_FIELD:
                    params.target(parser.text());
                    break;
                default:
                    // Unknown field: skip its value (including any nested object/array).
                    parser.skipChildren();
                    break;
            }
        }
        return params.build();
    }

    /**
     * Serializes these params for transport. Enums are written as a presence flag followed by
     * the enum; other fields use the stream's optional encodings.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        writeOptionalEnumValue(out, this.objectiveType);
        writeOptionalEnumValue(out, this.optimizerType);
        writeOptionalEnumValue(out, this.momentumType);
        out.writeOptionalDouble(this.learningRate);
        out.writeOptionalDouble(this.epsilon);
        out.writeOptionalDouble(this.momentumFactor);
        out.writeOptionalDouble(this.beta1);
        out.writeOptionalDouble(this.beta2);
        out.writeOptionalDouble(this.decayRate);
        out.writeOptionalInt(this.epochs);
        out.writeOptionalInt(this.batchSize);
        out.writeOptionalInt(this.loggingInterval);
        out.writeOptionalLong(this.seed);
        out.writeOptionalString(this.target);
    }

    /** Reads an enum preceded by a boolean presence flag; returns {@code null} when absent. */
    private static <E extends Enum<E>> E readOptionalEnumValue(StreamInput in, Class<E> enumClass) throws IOException {
        return in.readBoolean() ? in.readEnum(enumClass) : null;
    }

    /** Writes a boolean presence flag, followed by the enum itself when it is non-null. */
    private static <E extends Enum<E>> void writeOptionalEnumValue(StreamOutput out, E value) throws IOException {
        if (value == null) {
            out.writeBoolean(false);
            return;
        }
        out.writeBoolean(true);
        out.writeEnum(value);
    }

    /** Writes the non-null fields as a single XContent object; {@code null} fields are omitted. */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.objectiveType != null) {
            builder.field(OBJECTIVE_FIELD, this.objectiveType);
        }
        if (this.optimizerType != null) {
            builder.field(OPTIMISER_FIELD, this.optimizerType);
        }
        if (this.momentumType != null) {
            builder.field(MOMENTUM_TYPE_FIELD, this.momentumType);
        }
        if (this.learningRate != null) {
            builder.field(LEARNING_RATE_FIELD, this.learningRate);
        }
        if (this.epsilon != null) {
            builder.field(EPSILON_FIELD, this.epsilon);
        }
        if (this.momentumFactor != null) {
            builder.field(MOMENTUM_FACTOR_FIELD, this.momentumFactor);
        }
        if (this.beta1 != null) {
            builder.field(BETA1_FIELD, this.beta1);
        }
        if (this.beta2 != null) {
            builder.field(BETA2_FIELD, this.beta2);
        }
        if (this.decayRate != null) {
            builder.field(DECAY_RATE_FIELD, this.decayRate);
        }
        if (this.epochs != null) {
            builder.field(EPOCHS_FIELD, this.epochs);
        }
        if (this.batchSize != null) {
            builder.field(BATCH_SIZE_FIELD, this.batchSize);
        }
        if (this.loggingInterval != null) {
            builder.field(LOGGING_INTERVAL_FIELD, this.loggingInterval);
        }
        if (this.seed != null) {
            builder.field(SEED_FIELD, this.seed);
        }
        if (this.target != null) {
            builder.field(TARGET_FIELD, this.target);
        }
        builder.endObject();
        return builder;
    }

    /** Returns {@link #PARSE_FIELD_NAME}. */
    @Override
    public String getWriteableName() {
        return PARSE_FIELD_NAME;
    }

    /** Returns the version of this params type, currently {@code 1}. */
    @Override
    public int getVersion() {
        return 1;
    }

    /** Returns a new, empty builder. */
    @Generated
    public static LogisticRegressionParamsBuilder builder() {
        return new LogisticRegressionParamsBuilder();
    }

    /** Returns a builder pre-populated with every field of this instance. */
    @Generated
    public LogisticRegressionParamsBuilder toBuilder() {
        return new LogisticRegressionParamsBuilder()
            .objectiveType(this.objectiveType)
            .optimizerType(this.optimizerType)
            .momentumType(this.momentumType)
            .learningRate(this.learningRate)
            .epsilon(this.epsilon)
            .momentumFactor(this.momentumFactor)
            .beta1(this.beta1)
            .beta2(this.beta2)
            .decayRate(this.decayRate)
            .epochs(this.epochs)
            .batchSize(this.batchSize)
            .loggingInterval(this.loggingInterval)
            .seed(this.seed)
            .target(this.target);
    }

    @Generated
    public ObjectiveType getObjectiveType() {
        return this.objectiveType;
    }

    @Generated
    public OptimizerType getOptimizerType() {
        return this.optimizerType;
    }

    @Generated
    public MomentumType getMomentumType() {
        return this.momentumType;
    }

    @Generated
    public Double getLearningRate() {
        return this.learningRate;
    }

    @Generated
    public Double getEpsilon() {
        return this.epsilon;
    }

    @Generated
    public Double getMomentumFactor() {
        return this.momentumFactor;
    }

    @Generated
    public Double getBeta1() {
        return this.beta1;
    }

    @Generated
    public Double getBeta2() {
        return this.beta2;
    }

    @Generated
    public Double getDecayRate() {
        return this.decayRate;
    }

    @Generated
    public Integer getEpochs() {
        return this.epochs;
    }

    @Generated
    public Integer getBatchSize() {
        return this.batchSize;
    }

    @Generated
    public Integer getLoggingInterval() {
        return this.loggingInterval;
    }

    @Generated
    public Long getSeed() {
        return this.seed;
    }

    @Generated
    public String getTarget() {
        return this.target;
    }

    @Generated
    public void setObjectiveType(ObjectiveType objectiveType) {
        this.objectiveType = objectiveType;
    }

    @Generated
    public void setOptimizerType(OptimizerType optimizerType) {
        this.optimizerType = optimizerType;
    }

    @Generated
    public void setMomentumType(MomentumType momentumType) {
        this.momentumType = momentumType;
    }

    @Generated
    public void setLearningRate(Double learningRate) {
        this.learningRate = learningRate;
    }

    @Generated
    public void setEpsilon(Double epsilon) {
        this.epsilon = epsilon;
    }

    @Generated
    public void setMomentumFactor(Double momentumFactor) {
        this.momentumFactor = momentumFactor;
    }

    @Generated
    public void setBeta1(Double beta1) {
        this.beta1 = beta1;
    }

    @Generated
    public void setBeta2(Double beta2) {
        this.beta2 = beta2;
    }

    @Generated
    public void setDecayRate(Double decayRate) {
        this.decayRate = decayRate;
    }

    @Generated
    public void setEpochs(Integer epochs) {
        this.epochs = epochs;
    }

    @Generated
    public void setBatchSize(Integer batchSize) {
        this.batchSize = batchSize;
    }

    @Generated
    public void setLoggingInterval(Integer loggingInterval) {
        this.loggingInterval = loggingInterval;
    }

    @Generated
    public void setSeed(Long seed) {
        this.seed = seed;
    }

    @Generated
    public void setTarget(String target) {
        this.target = target;
    }

    /**
     * Field-by-field equality over all fourteen parameters. The {@link #canEqual} handshake keeps
     * the relation symmetric if a subclass narrows equality.
     */
    @Override
    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LogisticRegressionParams)) {
            return false;
        }
        LogisticRegressionParams other = (LogisticRegressionParams) o;
        return other.canEqual(this)
            && Objects.equals(getLearningRate(), other.getLearningRate())
            && Objects.equals(getEpsilon(), other.getEpsilon())
            && Objects.equals(getMomentumFactor(), other.getMomentumFactor())
            && Objects.equals(getBeta1(), other.getBeta1())
            && Objects.equals(getBeta2(), other.getBeta2())
            && Objects.equals(getDecayRate(), other.getDecayRate())
            && Objects.equals(getEpochs(), other.getEpochs())
            && Objects.equals(getBatchSize(), other.getBatchSize())
            && Objects.equals(getLoggingInterval(), other.getLoggingInterval())
            && Objects.equals(getSeed(), other.getSeed())
            && Objects.equals(getObjectiveType(), other.getObjectiveType())
            && Objects.equals(getOptimizerType(), other.getOptimizerType())
            && Objects.equals(getMomentumType(), other.getMomentumType())
            && Objects.equals(getTarget(), other.getTarget());
    }

    /** Returns whether {@code other} may be considered equal to an instance of this class. */
    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof LogisticRegressionParams;
    }

    /** Hash over the same fields compared by {@link #equals(Object)}. */
    @Override
    @Generated
    public int hashCode() {
        return Objects.hash(
            getLearningRate(),
            getEpsilon(),
            getMomentumFactor(),
            getBeta1(),
            getBeta2(),
            getDecayRate(),
            getEpochs(),
            getBatchSize(),
            getLoggingInterval(),
            getSeed(),
            getObjectiveType(),
            getOptimizerType(),
            getMomentumType(),
            getTarget()
        );
    }

    @Override
    @Generated
    public String toString() {
        return "LogisticRegressionParams(objectiveType=" + this.getObjectiveType()
            + ", optimizerType=" + this.getOptimizerType()
            + ", momentumType=" + this.getMomentumType()
            + ", learningRate=" + this.getLearningRate()
            + ", epsilon=" + this.getEpsilon()
            + ", momentumFactor=" + this.getMomentumFactor()
            + ", beta1=" + this.getBeta1()
            + ", beta2=" + this.getBeta2()
            + ", decayRate=" + this.getDecayRate()
            + ", epochs=" + this.getEpochs()
            + ", batchSize=" + this.getBatchSize()
            + ", loggingInterval=" + this.getLoggingInterval()
            + ", seed=" + this.getSeed()
            + ", target=" + this.getTarget() + ")";
    }

    /** Loss objective. */
    public enum ObjectiveType {
        HINGE,
        LOGMULTICLASS;

        /**
         * Resolves a constant by its exact (case-sensitive) name.
         *
         * @param value constant name
         * @return the matching constant
         * @throws IllegalArgumentException if {@code value} is {@code null} or names no constant;
         *     the underlying lookup failure is attached as the cause
         */
        public static ObjectiveType from(String value) {
            try {
                return valueOf(value);
            } catch (IllegalArgumentException | NullPointerException e) {
                throw new IllegalArgumentException("Wrong objective type", e);
            }
        }
    }

    /** Gradient optimizer. */
    public enum OptimizerType {
        SIMPLE_SGD,
        LINEAR_DECAY_SGD,
        SQRT_DECAY_SGD,
        ADA_GRAD,
        ADA_DELTA,
        ADAM,
        RMS_PROP;

        /**
         * Resolves a constant by its exact (case-sensitive) name.
         *
         * @param value constant name
         * @return the matching constant
         * @throws IllegalArgumentException if {@code value} is {@code null} or names no constant;
         *     the underlying lookup failure is attached as the cause
         */
        public static OptimizerType from(String value) {
            try {
                return valueOf(value);
            } catch (IllegalArgumentException | NullPointerException e) {
                throw new IllegalArgumentException("Wrong optimizer type", e);
            }
        }
    }

    /** Momentum variant. */
    public enum MomentumType {
        STANDARD,
        NESTEROV;

        /**
         * Resolves a constant by its exact (case-sensitive) name.
         *
         * @param value constant name
         * @return the matching constant
         * @throws IllegalArgumentException if {@code value} is {@code null} or names no constant;
         *     the underlying lookup failure is attached as the cause
         */
        public static MomentumType from(String value) {
            try {
                return valueOf(value);
            } catch (IllegalArgumentException | NullPointerException e) {
                throw new IllegalArgumentException("Wrong momentum type", e);
            }
        }
    }

    /** Fluent builder; every field starts as {@code null}. */
    @Generated
    public static class LogisticRegressionParamsBuilder {
        @Generated
        private ObjectiveType objectiveType;
        @Generated
        private OptimizerType optimizerType;
        @Generated
        private MomentumType momentumType;
        @Generated
        private Double learningRate;
        @Generated
        private Double epsilon;
        @Generated
        private Double momentumFactor;
        @Generated
        private Double beta1;
        @Generated
        private Double beta2;
        @Generated
        private Double decayRate;
        @Generated
        private Integer epochs;
        @Generated
        private Integer batchSize;
        @Generated
        private Integer loggingInterval;
        @Generated
        private Long seed;
        @Generated
        private String target;

        @Generated
        LogisticRegressionParamsBuilder() {
        }

        @Generated
        public LogisticRegressionParamsBuilder objectiveType(ObjectiveType objectiveType) {
            this.objectiveType = objectiveType;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder optimizerType(OptimizerType optimizerType) {
            this.optimizerType = optimizerType;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder momentumType(MomentumType momentumType) {
            this.momentumType = momentumType;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder learningRate(Double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder epsilon(Double epsilon) {
            this.epsilon = epsilon;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder momentumFactor(Double momentumFactor) {
            this.momentumFactor = momentumFactor;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder beta1(Double beta1) {
            this.beta1 = beta1;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder beta2(Double beta2) {
            this.beta2 = beta2;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder decayRate(Double decayRate) {
            this.decayRate = decayRate;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder epochs(Integer epochs) {
            this.epochs = epochs;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder batchSize(Integer batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder loggingInterval(Integer loggingInterval) {
            this.loggingInterval = loggingInterval;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder seed(Long seed) {
            this.seed = seed;
            return this;
        }

        @Generated
        public LogisticRegressionParamsBuilder target(String target) {
            this.target = target;
            return this;
        }

        /** Creates a params instance from the values set so far. */
        @Generated
        public LogisticRegressionParams build() {
            return new LogisticRegressionParams(
                this.objectiveType,
                this.optimizerType,
                this.momentumType,
                this.learningRate,
                this.epsilon,
                this.momentumFactor,
                this.beta1,
                this.beta2,
                this.decayRate,
                this.epochs,
                this.batchSize,
                this.loggingInterval,
                this.seed,
                this.target
            );
        }

        @Override
        @Generated
        public String toString() {
            return "LogisticRegressionParams.LogisticRegressionParamsBuilder(objectiveType=" + this.objectiveType
                + ", optimizerType=" + this.optimizerType
                + ", momentumType=" + this.momentumType
                + ", learningRate=" + this.learningRate
                + ", epsilon=" + this.epsilon
                + ", momentumFactor=" + this.momentumFactor
                + ", beta1=" + this.beta1
                + ", beta2=" + this.beta2
                + ", decayRate=" + this.decayRate
                + ", epochs=" + this.epochs
                + ", batchSize=" + this.batchSize
                + ", loggingInterval=" + this.loggingInterval
                + ", seed=" + this.seed
                + ", target=" + this.target + ")";
        }
    }
}

