package ai.catboost.spark;

import ai.catboost.CatBoostError;
import ai.catboost.spark.impl.CatBoostMasterWrapper;
import ai.catboost.spark.impl.CatBoostMasterWrapper$;
import ai.catboost.spark.impl.CatBoostWorkers;
import ai.catboost.spark.impl.CatBoostWorkers$;
import ai.catboost.spark.impl.CtrFeatures$;
import ai.catboost.spark.impl.CtrsContext;
import ai.catboost.spark.params.DatasetParamsTrait;
import ai.catboost.spark.params.Helpers$;
import ai.catboost.spark.params.QuantizationParams;
import ai.catboost.spark.params.TrainingParamsTrait;
import java.time.Duration;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.json4s.JsonAST;
import org.json4s.jackson.JsonMethods$;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.TClassTargetPreprocessor;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.TFullModel;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.TVector_i8;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.native_impl;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Tuple3;
import scala.collection.ArrayOps$;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.util.control.Breaks$;

/* compiled from: CatBoostPredictor.scala */
@ScalaSignature(bytes = "\u0006\u0005\u0005ee!\u0003\u0006\f!\u0003\r\tAEAF\u0011\u0015\u0019\u0005\u0001\"\u0001E\u0011\u0015A\u0005\u0001\"\u0005J\u0011%\tI\u0002AI\u0001\n#\tY\u0002C\u0005\u00022\u0001\t\n\u0011\"\u0005\u00024!9\u0011q\u0007\u0001\u0005\u0012\u0005e\u0002bBA$\u0001\u0019E\u0011\u0011\n\u0005\b\u0003+\u0002A\u0011KA,\u0011\u001d\tI\b\u0001C\u0001\u0003wB\u0011\"!\"\u0001#\u0003%\t!a\"\u0003-\r\u000bGOQ8pgR\u0004&/\u001a3jGR|'\u000f\u0016:bSRT!\u0001D\u0007\u0002\u000bM\u0004\u0018M]6\u000b\u00059y\u0011\u0001C2bi\n|wn\u001d;\u000b\u0003A\t!!Y5\u0004\u0001U\u00191cJ\u0019\u0014\t\u0001!r'\u0010\t\u0006+uyR\u0005M\u0007\u0002-)\u0011q\u0003G\u0001\u0003[2T!\u0001D\r\u000b\u0005iY\u0012AB1qC\u000eDWMC\u0001\u001d\u0003\ry'oZ\u0005\u0003=Y\u0011\u0011\u0002\u0015:fI&\u001cGo\u001c:\u0011\u0005\u0001\u001aS\"A\u0011\u000b\u0005\t2\u0012A\u00027j]\u0006dw-\u0003\u0002%C\t1a+Z2u_J\u0004\"AJ\u0014\r\u0001\u0011)\u0001\u0006\u0001b\u0001S\t9A*Z1s]\u0016\u0014\u0018C\u0001\u0016\u0015!\tYc&D\u0001-\u0015\u0005i\u0013!B:dC2\f\u0017BA\u0018-\u0005\u001dqu\u000e\u001e5j]\u001e\u0004\"AJ\u0019\u0005\u000bI\u0002!\u0019A\u001a\u0003\u000b5{G-\u001a7\u0012\u0005)\"\u0004\u0003B\u000b6?AJ!A\u000e\f\u0003\u001fA\u0013X\rZ5di&|g.T8eK2\u0004\"\u0001O\u001e\u000e\u0003eR!AO\u0006\u0002\rA\f'/Y7t\u0013\ta\u0014H\u0001\nECR\f7/\u001a;QCJ\fWn\u001d+sC&$\bC\u0001 
B\u001b\u0005y$B\u0001!\u0017\u0003\u0011)H/\u001b7\n\u0005\t{$!\u0006#fM\u0006,H\u000e\u001e)be\u0006l7o\u0016:ji\u0006\u0014G.Z\u0001\u0007I%t\u0017\u000e\u001e\u0013\u0015\u0003\u0015\u0003\"a\u000b$\n\u0005\u001dc#\u0001B+oSR\fq#\u00193e\u000bN$\u0018.\\1uK\u0012\u001cEO\u001d$fCR,(/Z:\u0015\u000f)SFL\u00189\u0002\u0010A)1fS'R)&\u0011A\n\f\u0002\u0007)V\u0004H.Z\u001a\u0011\u00059{U\"A\u0006\n\u0005A[!\u0001\u0002)p_2\u00042a\u000b*N\u0013\t\u0019FFA\u0003BeJ\f\u0017\u0010\u0005\u0002V16\taK\u0003\u0002X\u0017\u0005!\u0011.\u001c9m\u0013\tIfKA\u0006DiJ\u001c8i\u001c8uKb$\b\"B.\u0003\u0001\u0004i\u0015AE9vC:$\u0018N_3e)J\f\u0017N\u001c)p_2DQ!\u0018\u0002A\u0002E\u000b!#];b]RL'0\u001a3Fm\u0006d\u0007k\\8mg\")qL\u0001a\u0001A\u0006IR\u000f\u001d3bi\u0016$7)\u0019;C_>\u001cHOS:p]B\u000b'/Y7t!\t\tWN\u0004\u0002cU:\u00111\r\u001b\b\u0003I\u001el\u0011!\u001a\u0006\u0003MF\ta\u0001\u0010:p_Rt\u0014\"\u0001\u000f\n\u0005%\\\u0012A\u00026t_:$4/\u0003\u0002lY\u00069\u0001/Y2lC\u001e,'BA5\u001c\u0013\tqwNA\u0004K\u001f\nTWm\u0019;\u000b\u0005-d\u0007bB9\u0003!\u0003\u0005\rA]\u0001\u0018G2\f7o\u001d+be\u001e,G\u000f\u0015:faJ|7-Z:t_J\u00042aK:v\u0013\t!HF\u0001\u0004PaRLwN\u001c\t\u0004m\u0006-Q\"A<\u000b\u0005aL\u0018a\u00038bi&4XmX5na2T!A_>\u0002\u0007M\u00148M\u0003\u0002}{\u0006!1m\u001c:f\u0015\tqx0\u0001\tdCR\u0014wn\\:ui)|6\u000f]1sW*\u0019A\"!\u0001\u000b\u00079\t\u0019A\u0003\u0003\u0002\u0006\u0005\u001d\u0011AB=b]\u0012,\u0007P\u0003\u0002\u0002\n\u0005\u0011!/^\u0005\u0004\u0003\u001b9(\u0001\u0007+DY\u0006\u001c8\u000fV1sO\u0016$\bK]3qe>\u001cWm]:pe\"I\u0011\u0011\u0003\u0002\u0011\u0002\u0003\u0007\u00111C\u0001\u0019g\u0016\u0014\u0018.\u00197ju\u0016$G*\u00192fY\u000e{gN^3si\u0016\u0014\bc\u0001<\u0002\u0016%\u0019\u0011qC<\u0003\u0015Q3Vm\u0019;pe~K\u0007(A\u0011bI\u0012,5\u000f^5nCR,Gm\u0011;s\r\u0016\fG/\u001e:fg\u0012\"WMZ1vYR$C'\u0006\u0002\u0002\u001e)\u001a!/a\b,\u0005\u0005\u0005\u0002\u0003BA\u0012\u0003[i!!!\n\u000b\t\u0005\u001d\u0012\u0011F\u000
1\nk:\u001c\u0007.Z2lK\u0012T1!a\u000b-\u0003)\tgN\\8uCRLwN\\\u0005\u0005\u0003_\t)CA\tv]\u000eDWmY6fIZ\u000b'/[1oG\u0016\f\u0011%\u00193e\u000bN$\u0018.\\1uK\u0012\u001cEO\u001d$fCR,(/Z:%I\u00164\u0017-\u001e7uIU*\"!!\u000e+\t\u0005M\u0011qD\u0001\u0019aJ,\u0007O]8dKN\u001c()\u001a4pe\u0016$&/Y5oS:<GCBA\u001e\u0003\u0007\n)\u0005\u0005\u0004,\u00176\u000b\u0016Q\b\t\u0004\u001d\u0006}\u0012bAA!\u0017\t92)\u0019;C_>\u001cH\u000f\u0016:bS:LgnZ\"p]R,\u0007\u0010\u001e\u0005\u00067\u0016\u0001\r!\u0014\u0005\u0006;\u0016\u0001\r!U\u0001\fGJ,\u0017\r^3N_\u0012,G\u000eF\u00021\u0003\u0017Bq!!\u0014\u0007\u0001\u0004\ty%A\u0005gk2dWj\u001c3fYB\u0019a/!\u0015\n\u0007\u0005MsO\u0001\u0006U\rVdG.T8eK2\fQ\u0001\u001e:bS:$2\u0001MA-\u0011\u001d\tYf\u0002a\u0001\u0003;\nq\u0001Z1uCN,G\u000f\r\u0003\u0002`\u00055\u0004CBA1\u0003O\nY'\u0004\u0002\u0002d)\u0019\u0011Q\r\r\u0002\u0007M\fH.\u0003\u0003\u0002j\u0005\r$a\u0002#bi\u0006\u001cX\r\u001e\t\u0004M\u00055D\u0001DA8\u00033\n\t\u0011!A\u0003\u0002\u0005E$aA0%cE\u0019!&a\u001d\u0011\u0007-\n)(C\u0002\u0002x1\u00121!\u00118z\u0003\r1\u0017\u000e\u001e\u000b\u0006a\u0005u\u0014\u0011\u0011\u0005\u0007\u0003\u007fB\u0001\u0019A'\u0002\u0013Q\u0014\u0018-\u001b8Q_>d\u0007\u0002CAB\u0011A\u0005\t\u0019A)\u0002\u0013\u00154\u0018\r\u001c)p_2\u001c\u0018!\u00044ji\u0012\"WMZ1vYR$#'\u0006\u0002\u0002\n*\u001a\u0011+a\b\u0013\r\u00055\u0015\u0011SAJ\r\u0019\ty\t\u0001\u0001\u0002\f\naAH]3gS:,W.\u001a8u}A!a\nA\u00131!\rA\u0014QS\u0005\u0004\u0003/K$a\u0005+sC&t\u0017N\\4QCJ\fWn\u001d+sC&$\b")
/* loaded from: input_file:ai/catboost/spark/CatBoostPredictorTrait.class */
/*
 * Shared estimator logic for CatBoost Spark ML predictors. It quantizes the input pools,
 * optionally adds estimated CTR features, and runs distributed training through a CatBoost
 * master (CatBoostMasterWrapper) and Spark-side workers (CatBoostWorkers).
 *
 * Type parameters:
 *   Learner - the concrete Spark Predictor (estimator) type over Vector features.
 *   Model   - the concrete Spark PredictionModel type produced by training.
 *
 * NOTE(review): this file looks like decompiler output of the Scala trait in
 * CatBoostPredictor.scala (see the ScalaSignature annotation above). Several constructs are
 * not valid Java source as written:
 *   - the (Object) null element of a typed Tuple3 in addEstimatedCtrFeatures;
 *   - the unreachable `break;` after `throw Breaks$.MODULE$.break();` in fit;
 *   - the anonymous Runnable that takes constructor arguments in fit.
 * The Scala source should be treated as authoritative.
 */
public interface CatBoostPredictorTrait<Learner extends Predictor<Vector, Learner, Model>, Model extends PredictionModel<Vector, Model>> extends DatasetParamsTrait, DefaultParamsWritable {
    /**
     * Adds estimated CTR features to the training and evaluation pools when categorical
     * features have more unique values than the one-hot encoding limit.
     *
     * <p>The maximum number of unique values over the learn pool's categorical features is
     * computed natively. It is compared with the one-hot max size returned by the native
     * GetOneHotMaxSize, which receives three inputs:
     * <ul>
     *   <li>that maximum;</li>
     *   <li>whether the pool's label column param is defined;</li>
     *   <li>the compact JSON params.</li>
     * </ul>
     * If the maximum exceeds the one-hot max size, the work is delegated to
     * CtrFeatures.addCtrsAsEstimated. Otherwise the pools are returned unchanged.
     *
     * @param pool training pool; its quantizedFeaturesInfo is dereferenced, so it must already
     *     be quantized
     * @param poolArr evaluation pools
     * @param jObject CatBoost training parameters as a JSON object
     * @param option optional class target preprocessor, forwarded to addCtrsAsEstimated
     * @param tVector_i8 forwarded unchanged to addCtrsAsEstimated. It is presumably a
     *     serialized label converter (TODO confirm).
     * @return (training pool, evaluation pools, CTR context). The CTR context is null when no
     *     CTR features were added.
     */
    default Tuple3<Pool, Pool[], CtrsContext> addEstimatedCtrFeatures(Pool pool, Pool[] poolArr, JsonAST.JObject jObject, Option<TClassTargetPreprocessor> option, TVector_i8 tVector_i8) {
        int CalcMaxCategoricalFeaturesUniqueValuesCountOnLearn = native_impl.CalcMaxCategoricalFeaturesUniqueValuesCountOnLearn(pool.quantizedFeaturesInfo().__deref__());
        int GetOneHotMaxSize = native_impl.GetOneHotMaxSize(CalcMaxCategoricalFeaturesUniqueValuesCountOnLearn, pool.isDefined(pool.labelCol()), JsonMethods$.MODULE$.compact(jObject));
        // NOTE(review): `(Object) null` is decompiler residue. The null third element means
        // "no CTR context"; callers check ctrsContext() != null.
        return CalcMaxCategoricalFeaturesUniqueValuesCountOnLearn > GetOneHotMaxSize ? CtrFeatures$.MODULE$.addCtrsAsEstimated(pool, poolArr, jObject, GetOneHotMaxSize, option, tVector_i8) : new Tuple3<>(pool, poolArr, (Object) null);
    }

    /** Scala default for the {@code option} argument of addEstimatedCtrFeatures: no class target preprocessor. */
    default Option<TClassTargetPreprocessor> addEstimatedCtrFeatures$default$4() {
        return None$.MODULE$;
    }

    /** Scala default for the last argument of addEstimatedCtrFeatures: a new, empty TVector_i8. */
    default TVector_i8 addEstimatedCtrFeatures$default$5() {
        return new TVector_i8();
    }

    /**
     * Prepares quantized pools for training and bundles the results into a
     * CatBoostTrainingContext.
     *
     * <p>It converts this estimator's Spark ML params to CatBoost JSON params and adds
     * estimated CTR features if needed.
     *
     * @param pool quantized training pool
     * @param poolArr quantized evaluation pools
     * @return (training pool, evaluation pools, training context). The context holds:
     *     <ul>
     *       <li>the CTR context, possibly null;</li>
     *       <li>the JSON params;</li>
     *       <li>a new, empty TVector_i8, which fit reads back as serializedLabelConverter().</li>
     *     </ul>
     */
    default Tuple3<Pool, Pool[], CatBoostTrainingContext> preprocessBeforeTraining(Pool pool, Pool[] poolArr) {
        JsonAST.JObject sparkMlParamsToCatBoostJsonParams = Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParams(this, Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParams$default$2());
        // NOTE(review): the TVector_i8 passed here (from default$5) is a different instance
        // from the fresh one stored in the training context below. If addCtrsAsEstimated
        // populates it, that content does not reach the context. Confirm this is intended;
        // subclasses may override this method.
        Tuple3<Pool, Pool[], CtrsContext> addEstimatedCtrFeatures = addEstimatedCtrFeatures(pool, poolArr, sparkMlParamsToCatBoostJsonParams, addEstimatedCtrFeatures$default$4(), addEstimatedCtrFeatures$default$5());
        // Residue of a Scala tuple pattern match: a null result raises MatchError.
        if (addEstimatedCtrFeatures == null) {
            throw new MatchError(addEstimatedCtrFeatures);
        }
        Tuple3 tuple3 = new Tuple3((Pool) addEstimatedCtrFeatures._1(), (Pool[]) addEstimatedCtrFeatures._2(), (CtrsContext) addEstimatedCtrFeatures._3());
        return new Tuple3<>((Pool) tuple3._1(), (Pool[]) tuple3._2(), new CatBoostTrainingContext((CtrsContext) tuple3._3(), sparkMlParamsToCatBoostJsonParams, new TVector_i8()));
    }

    /**
     * Wraps a trained native model into the concrete Model type.
     * Each concrete predictor implements this.
     *
     * @param tFullModel the trained native CatBoost model
     * @return the Spark ML model wrapping it
     */
    Model createModel(TFullModel tFullModel);

    /**
     * Spark Predictor training hook. It wraps the dataset into a Pool, copies this
     * estimator's param values onto that pool, and fits without evaluation pools.
     *
     * @param dataset training data
     * @return the trained model
     */
    default Model train(Dataset<?> dataset) {
        Pool pool = new Pool(dataset);
        copyValues(pool, copyValues$default$2());
        return fit(pool, fit$default$2());
    }

    /**
     * Trains a model on {@code pool}, using {@code poolArr} as evaluation pools.
     *
     * <p>Steps:
     * <ol>
     *   <li>Check param compatibility of the train pool and of each eval pool.</li>
     *   <li>Quantize the train pool, if needed, with QuantizationParams copied from this
     *       estimator.</li>
     *   <li>Quantize each unquantized eval pool with the train pool's quantizedFeaturesInfo.</li>
     *   <li>Run preprocessBeforeTraining, which may add estimated CTR features.</li>
     *   <li>Prepare the datasets for training, using index 0 for train and i + 1 for eval #i.</li>
     *   <li>Create the master wrapper and the workers.</li>
     *   <li>Run the TrainingDriver and the workers concurrently. Retry while the workers lose
     *       their connection and the worker failure count is below workerMaxFailures.</li>
     *   <li>Attach a CTR provider to the native model if a CTR context exists, then build the
     *       Model.</li>
     * </ol>
     *
     * @param pool training pool; quantized here if it is not already
     * @param poolArr evaluation pools; may be empty
     * @return the trained model
     * @throws CatBoostError if the workers lose connection and their failure count has reached
     *     workerMaxFailures
     * @throws ExecutionException (rethrown) if the driver or workers fail for any reason other
     *     than CatBoostWorkersConnectionLostException
     */
    default Model fit(Pool pool, Pool[] poolArr) {
        Pool quantize;
        TFullModel nativeModelResult;
        Helpers$.MODULE$.checkParamsCompatibility(getClass().getName(), this, "trainPool", pool);
        RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), poolArr.length).foreach$mVc$sp(i -> {
            Helpers$.MODULE$.checkParamsCompatibility(this.getClass().getName(), this, new StringBuilder(10).append("evalPool #").append(i).toString(), poolArr[i]);
        });
        SparkSession sparkSession = pool.data().sparkSession();
        // Quantize the train pool only if needed. When it is already quantized, pool2 is the
        // caller's own pool object.
        if (pool.isQuantized()) {
            quantize = pool;
        } else {
            QuantizationParams quantizationParams = new QuantizationParams();
            copyValues(quantizationParams, copyValues$default$2());
            ((Logging) this).logInfo(() -> {
                return "fit. schedule quantization for train dataset";
            });
            quantize = pool.quantize(quantizationParams);
        }
        Pool pool2 = quantize;
        // create.elem counts eval pools visited so far. It is used only to print the
        // 0-based eval index in the log message.
        IntRef create = IntRef.create(0);
        // Eval pools are quantized with the train pool's feature info so their borders match.
        Pool[] poolArr2 = (Pool[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(poolArr), pool3 -> {
            create.elem++;
            if (pool3.isQuantized()) {
                return pool3;
            }
            ((Logging) this).logInfo(() -> {
                return new StringBuilder(45).append("fit. schedule quantization for eval dataset #").append(create.elem - 1).toString();
            });
            return pool3.quantize(pool2.quantizedFeaturesInfo());
        }, ClassTag$.MODULE$.apply(Pool.class));
        Tuple3<Pool, Pool[], CatBoostTrainingContext> preprocessBeforeTraining = preprocessBeforeTraining(pool2, poolArr2);
        // Residue of a Scala tuple pattern match: a null result raises MatchError.
        if (preprocessBeforeTraining == null) {
            throw new MatchError(preprocessBeforeTraining);
        }
        Tuple3 tuple3 = new Tuple3((Pool) preprocessBeforeTraining._1(), (Pool[]) preprocessBeforeTraining._2(), (CatBoostTrainingContext) preprocessBeforeTraining._3());
        Pool pool4 = (Pool) tuple3._1();
        Pool[] poolArr3 = (Pool[]) tuple3._2();
        CatBoostTrainingContext catBoostTrainingContext = (CatBoostTrainingContext) tuple3._3();
        // Partition count: the sparkPartitionCount param if set, otherwise the Spark worker count.
        int unboxToInt = BoxesRunTime.unboxToInt(get(((TrainingParamsTrait) this).sparkPartitionCount()).getOrElse(() -> {
            return SparkHelpers$.MODULE$.getWorkerCount(sparkSession);
        }));
        ((Logging) this).logInfo(() -> {
            return new StringBuilder(20).append("fit. partitionCount=").append(unboxToInt).toString();
        });
        ((Logging) this).logInfo(() -> {
            return "fit. train.prepareDatasetForTraining: start";
        });
        // The train dataset uses index 0; eval dataset #i uses index i + 1 (see below).
        DatasetForTraining prepareDatasetForTraining = DataHelpers$.MODULE$.prepareDatasetForTraining(pool4, (byte) 0, unboxToInt);
        ((Logging) this).logInfo(() -> {
            return "fit. train.prepareDatasetForTraining: finish";
        });
        DatasetForTraining[] datasetForTrainingArr = (DatasetForTraining[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(ArrayOps$.MODULE$.zipWithIndex$extension(Predef$.MODULE$.refArrayOps(poolArr3))), tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            Pool pool5 = (Pool) tuple2._1();
            int _2$mcI$sp = tuple2._2$mcI$sp();
            ((Logging) this).logInfo(() -> {
                return new StringBuilder(44).append("fit. eval #").append(_2$mcI$sp).append(".prepareDatasetForTraining: start").toString();
            });
            DatasetForTraining prepareDatasetForTraining2 = DataHelpers$.MODULE$.prepareDatasetForTraining(pool5, (byte) (_2$mcI$sp + 1), unboxToInt);
            ((Logging) this).logInfo(() -> {
                return new StringBuilder(45).append("fit. eval #").append(_2$mcI$sp).append(".prepareDatasetForTraining: finish").toString();
            });
            return prepareDatasetForTraining2;
        }, ClassTag$.MODULE$.apply(DatasetForTraining.class));
        // Online CTR metadata is passed to master and workers only when CTR features were added.
        String precomputedOnlineCtrMetaDataAsJsonString = catBoostTrainingContext.ctrsContext() != null ? catBoostTrainingContext.ctrsContext().precomputedOnlineCtrMetaDataAsJsonString() : null;
        CatBoostMasterWrapper apply = CatBoostMasterWrapper$.MODULE$.apply(prepareDatasetForTraining, Predef$.MODULE$.copyArrayToImmutableIndexedSeq(datasetForTrainingArr), JsonMethods$.MODULE$.compact(catBoostTrainingContext.catBoostJsonParams()), precomputedOnlineCtrMetaDataAsJsonString);
        // duration = connectTimeout, duration2 = workerInitializationTimeout,
        // unboxToInt2 = workerMaxFailures.
        Duration duration = (Duration) getOrDefault(((TrainingParamsTrait) this).connectTimeout());
        Duration duration2 = (Duration) getOrDefault(((TrainingParamsTrait) this).workerInitializationTimeout());
        int unboxToInt2 = BoxesRunTime.unboxToInt(getOrDefault(((TrainingParamsTrait) this).workerMaxFailures()));
        CatBoostWorkers apply2 = CatBoostWorkers$.MODULE$.apply(sparkSession, unboxToInt, duration, duration2, BoxesRunTime.unboxToInt(getOrDefault(((TrainingParamsTrait) this).workerListeningPort())), prepareDatasetForTraining, Predef$.MODULE$.copyArrayToImmutableIndexedSeq(datasetForTrainingArr), catBoostTrainingContext.catBoostJsonParams(), catBoostTrainingContext.serializedLabelConverter(), precomputedOnlineCtrMetaDataAsJsonString, apply.savedPoolsFuture());
        // Runs the TrainingDriver and the workers runnable concurrently. It is shut down on
        // both the success and failure paths below, which is equivalent to a finally block.
        ExecutorService newCachedThreadPool = Executors.newCachedThreadPool();
        try {
            // Retry loop: each iteration creates a fresh TrainingDriver. A successful run exits
            // via Breaks.break(); a lost worker connection loops again.
            Breaks$.MODULE$.breakable(() -> {
                while (true) {
                    TrainingDriver trainingDriver = new TrainingDriver(BoxesRunTime.unboxToInt(this.getOrDefault(((TrainingParamsTrait) this).trainingDriverListeningPort())), unboxToInt, (Function1<WorkerInfo[], BoxedUnit>) workerInfoArr -> {
                        apply.trainCallback(workerInfoArr);
                        return BoxedUnit.UNIT;
                    }, duration, duration2, TrainingDriver$.MODULE$.$lessinit$greater$default$6(), TrainingDriver$.MODULE$.$lessinit$greater$default$7());
                    try {
                        int listeningPort = trainingDriver.getListeningPort();
                        ((Logging) this).logInfo(() -> {
                            return new StringBuilder(37).append("fit. TrainingDriver listening port = ").append(listeningPort).toString();
                        });
                        ((Logging) this).logInfo(() -> {
                            return "fit. Training started";
                        });
                        ExecutorCompletionService<BoxedUnit> executorCompletionService = new ExecutorCompletionService<>(newCachedThreadPool);
                        try {
                            // Submit the driver ("master") and the workers (connecting to the
                            // driver's listening port), then wait for both futures.
                            ai.catboost.spark.impl.Helpers$.MODULE$.waitForTwoFutures(executorCompletionService, executorCompletionService.submit(trainingDriver, BoxedUnit.UNIT), "master", executorCompletionService.submit(new Runnable(null, apply2, listeningPort) { // from class: ai.catboost.spark.CatBoostPredictorTrait$$anon$1
                                private final CatBoostWorkers workers$1;
                                private final int listeningPort$1;

                                @Override // java.lang.Runnable
                                public void run() {
                                    this.workers$1.run(this.listeningPort$1);
                                }

                                {
                                    this.workers$1 = apply2;
                                    this.listeningPort$1 = listeningPort;
                                }
                            }, BoxedUnit.UNIT), "workers");
                            // Both tasks finished: leave the retry loop. This control throwable
                            // also passes through the outer catch (Throwable) below, which closes
                            // the driver before rethrowing.
                            throw Breaks$.MODULE$.break();
                            // Unreachable: decompiler residue.
                            break;
                        } catch (ExecutionException e) {
                            // Only a lost worker connection is retried; anything else propagates.
                            if (!(e.getCause() instanceof CatBoostWorkersConnectionLostException)) {
                                throw e;
                            }
                            BoxedUnit boxedUnit = BoxedUnit.UNIT;
                            if (apply2.workerFailureCount() >= unboxToInt2) {
                                throw new CatBoostError(new StringBuilder(39).append("CatBoost workers failed at least ").append(unboxToInt2).append(" times").toString());
                            }
                            // Dead branch (constant-false condition): decompiler residue.
                            if (1 == 0) {
                                throw Breaks$.MODULE$.break();
                            }
                            ((Logging) this).log().info("CatBoost master: communication with some of the workers has been lost. Retry training");
                            trainingDriver.close(true, false);
                        }
                    } catch (Throwable th) {
                        // Closes the driver on any exit from the attempt, including the
                        // successful Breaks.break(), then rethrows.
                        trainingDriver.close(true, false);
                        throw th;
                    }
                }
            });
            newCachedThreadPool.shutdown();
            ((Logging) this).logInfo(() -> {
                return "fit. Training finished";
            });
            // Estimated CTR features need a CTR provider attached to the native model.
            if (catBoostTrainingContext.ctrsContext() != null) {
                ((Logging) this).logInfo(() -> {
                    return "fit. Add CtrProvider to model";
                });
                nativeModelResult = CtrFeatures$.MODULE$.addCtrProviderToModel(apply.nativeModelResult(), catBoostTrainingContext.ctrsContext(), pool2, poolArr2);
            } else {
                nativeModelResult = apply.nativeModelResult();
            }
            Model createModel = createModel(nativeModelResult);
            // NOTE(review): unpersist runs only on the success path, so an exception leaves
            // pool4 as-is. Also, when the input pool was already quantized and no CTR features
            // were added, pool4 is the caller's own pool object, so this unpersists the caller's
            // pool. Confirm against Pool.unpersist semantics.
            pool4.unpersist();
            return createModel;
        } catch (Throwable th) {
            newCachedThreadPool.shutdown();
            throw th;
        }
    }

    /** Scala default for the eval pools argument of fit: an empty Pool array. */
    default Pool[] fit$default$2() {
        return (Pool[]) Array$.MODULE$.apply(Nil$.MODULE$, ClassTag$.MODULE$.apply(Pool.class));
    }

    /** Scala trait initializer generated by the compiler; this trait has no state to initialize. */
    static void $init$(CatBoostPredictorTrait catBoostPredictorTrait) {
    }
}
