package org.apache.spark.ml.bundle.ops.classification;

import com.fasterxml.jackson.annotation.JsonProperty;
import ml.combust.bundle.BundleContext;
import ml.combust.bundle.dsl.Bundle$BuiltinOps$classification$;
import ml.combust.bundle.dsl.HasAttributes;
import ml.combust.bundle.dsl.Model;
import ml.combust.bundle.dsl.NodeShape;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.op.OpModel;
import org.apache.spark.ml.bundle.ParamSpec;
import org.apache.spark.ml.bundle.ParamSpec$;
import org.apache.spark.ml.bundle.SimpleParamSpec;
import org.apache.spark.ml.bundle.SimpleSparkOp;
import org.apache.spark.ml.bundle.SparkBundleContext;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import scala.Array$;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.IntRef;

/* compiled from: GBTClassifierOpV20.scala */
@ScalaSignature(bytes = "\u0006\u0001\u001d4A!\u0001\u0002\u0001#\t\u0011rI\u0011+DY\u0006\u001c8/\u001b4jKJ|\u0005O\u0016\u001a1\u0015\t\u0019A!\u0001\bdY\u0006\u001c8/\u001b4jG\u0006$\u0018n\u001c8\u000b\u0005\u00151\u0011aA8qg*\u0011q\u0001C\u0001\u0007EVtG\r\\3\u000b\u0005%Q\u0011AA7m\u0015\tYA\"A\u0003ta\u0006\u00148N\u0003\u0002\u000e\u001d\u00051\u0011\r]1dQ\u0016T\u0011aD\u0001\u0004_J<7\u0001A\n\u0003\u0001I\u00012a\u0005\u000b\u0017\u001b\u00051\u0011BA\u000b\u0007\u00055\u0019\u0016.\u001c9mKN\u0003\u0018M]6PaB\u0011q#G\u0007\u00021)\u00111\u0001C\u0005\u00035a\u0011ac\u0012\"U\u00072\f7o]5gS\u000e\fG/[8o\u001b>$W\r\u001c\u0005\u00069\u0001!\t!H\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0003y\u0001\"a\b\u0001\u000e\u0003\tAq!\t\u0001C\u0002\u0013\u0005#%A\u0003N_\u0012,G.F\u0001$!\u0011!3&\f\f\u000e\u0003\u0015R!AJ\u0014\u0002\u0005=\u0004(BA\u0004)\u0015\tI#&A\u0004d_6\u0014Wo\u001d;\u000b\u0003%I!\u0001L\u0013\u0003\u000f=\u0003Xj\u001c3fYB\u00111CL\u0005\u0003_\u0019\u0011!c\u00159be.\u0014UO\u001c3mK\u000e{g\u000e^3yi\"1\u0011\u0007\u0001Q\u0001\n\r\na!T8eK2\u0004\u0003\"B\u001a\u0001\t\u0003\"\u0014!C:qCJ\\Gj\\1e)\u00111R'Q%\t\u000bY\u0012\u0004\u0019A\u001c\u0002\u0007ULG\r\u0005\u00029}9\u0011\u0011\bP\u0007\u0002u)\t1(A\u0003tG\u0006d\u0017-\u0003\u0002>u\u00051\u0001K]3eK\u001aL!a\u0010!\u0003\rM#(/\u001b8h\u0015\ti$\bC\u0003Ce\u0001\u00071)A\u0003tQ\u0006\u0004X\r\u0005\u0002E\u000f6\tQI\u0003\u0002GO\u0005\u0019Am\u001d7\n\u0005!+%!\u0003(pI\u0016\u001c\u0006.\u00199f\u0011\u0015Q%\u00071\u0001\u0017\u0003\u0015iw\u000eZ3m\u0011\u0015a\u0005\u0001\"\u0011N\u0003-\u0019\b/\u0019:l\u0013:\u0004X\u000f^:\u0015\u00059k\u0006cA(X5:\u0011\u0001+\u0016\b\u0003#Rk\u0011A\u0015\u0006\u0003'B\ta\u0001\u0010:p_Rt\u0014\"A\u001e\n\u0005YS\u0014a\u00029bG.\fw-Z\u0005\u00031f\u00131aU3r\u0015\t1&\b\u0005\u0002\u00147&\u0011AL\u0002\u0002\n!\u0006\u0014\u0018-\\*qK\u000eDQAX&A\u0002Y\t1a\u001c2k\u0011\u0015\u0001\u0007\u0001\"\u0011b\u00031\u0019\b/\u0019:l\u0
01fV$\b/\u001e;t)\t\u0011g\rE\u0002P/\u000e\u0004\"a\u00053\n\u0005\u00154!aD*j[BdW\rU1sC6\u001c\u0006/Z2\t\u000by{\u0006\u0019\u0001\f")
/* loaded from: input_file:org/apache/spark/ml/bundle/ops/classification/GBTClassifierOpV20.class */
/**
 * Bundle op that stores and loads Spark {@link GBTClassificationModel} instances under the
 * builtin {@code gbt_classifier} op name.
 *
 * <p>NOTE(review): this source was decompiled from Scala bytecode. Synthetic names such as
 * {@code GBTClassifierOpV20$$anon$1$$anonfun$1} and {@code Value$.m1848long} are
 * compiler/decompiler artifacts that refer to members defined outside this file.
 */
public class GBTClassifierOpV20 extends SimpleSparkOp<GBTClassificationModel> {
    /** Attribute serializer for the model; created once in the constructor. */
    private final OpModel<SparkBundleContext, GBTClassificationModel> Model;

    /**
     * Returns the {@link OpModel} used to store and load this model's bundle attributes.
     *
     * @return the op model created in the constructor
     */
    @Override // ml.combust.bundle.op.OpNode
    public OpModel<SparkBundleContext, GBTClassificationModel> Model() {
        return this.Model;
    }

    /**
     * Builds a new GBT model with the given uid, copying the trees, tree weights and feature
     * count from {@code gBTClassificationModel}.
     *
     * @param str uid for the new model
     * @param nodeShape node shape (not read by this method)
     * @param gBTClassificationModel model whose trees, weights and feature count are copied
     * @return a new model carrying uid {@code str}
     */
    @Override // org.apache.spark.ml.bundle.SimpleSparkOp
    public GBTClassificationModel sparkLoad(String str, NodeShape nodeShape, GBTClassificationModel gBTClassificationModel) {
        return new GBTClassificationModel(str, gBTClassificationModel.trees(), gBTClassificationModel.treeWeights(), gBTClassificationModel.numFeatures());
    }

    /**
     * Declares the op's single input: the {@code "features"} port bound to the model's
     * features column.
     *
     * @param gBTClassificationModel model whose {@code featuresCol} is read
     * @return a one-element sequence of input param specs
     */
    @Override // org.apache.spark.ml.bundle.SimpleSparkOp
    public Seq<ParamSpec> sparkInputs(GBTClassificationModel gBTClassificationModel) {
        return (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new SimpleParamSpec[]{ParamSpec$.MODULE$.apply(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("features"), gBTClassificationModel.featuresCol()))}));
    }

    /**
     * Declares the op's single output: the {@code "prediction"} port bound to the model's
     * prediction column.
     *
     * @param gBTClassificationModel model whose {@code predictionCol} is read
     * @return a one-element sequence of output param specs
     */
    @Override // org.apache.spark.ml.bundle.SimpleSparkOp
    public Seq<SimpleParamSpec> sparkOutputs(GBTClassificationModel gBTClassificationModel) {
        return (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new SimpleParamSpec[]{ParamSpec$.MODULE$.apply(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("prediction"), gBTClassificationModel.predictionCol()))}));
    }

    /** Creates the op and its attribute serializer. */
    public GBTClassifierOpV20() {
        super(ClassTag$.MODULE$.apply(GBTClassificationModel.class));
        this.Model = new OpModel<SparkBundleContext, GBTClassificationModel>(this) { // from class: org.apache.spark.ml.bundle.ops.classification.GBTClassifierOpV20$$anon$1
            private final Class<GBTClassificationModel> klazz = GBTClassificationModel.class;

            /** Returns the Spark model class handled by this serializer. */
            @Override // ml.combust.bundle.op.OpModel
            public Class<GBTClassificationModel> klazz() {
                return this.klazz;
            }

            /** Returns the builtin {@code gbt_classifier} op name. */
            @Override // ml.combust.bundle.op.OpModel
            public String opName() {
                return Bundle$BuiltinOps$classification$.MODULE$.gbt_classifier();
            }

            /**
             * Writes {@code num_features}, {@code num_classes} (always 2), {@code tree_weights}
             * and {@code trees} onto the bundle model.
             *
             * <p>Each tree is mapped to a string by {@code GBTClassifierOpV20$$anon$1$$anonfun$1},
             * which receives a counter starting at 0. NOTE(review): presumably it saves each tree
             * as a sub-bundle and returns its name; confirm against that class.
             */
            @Override // ml.combust.bundle.op.OpModel
            public Model store(Model model, GBTClassificationModel gBTClassificationModel, BundleContext<SparkBundleContext> bundleContext) {
                return (Model) ((HasAttributes) ((HasAttributes) ((HasAttributes) model
                        .withValue("num_features", Value$.MODULE$.m1848long(gBTClassificationModel.numFeatures())))
                        .withValue("num_classes", Value$.MODULE$.m1848long(2L)))
                        .withValue("tree_weights", Value$.MODULE$.doubleList(Predef$.MODULE$.wrapDoubleArray(gBTClassificationModel.treeWeights()))))
                        .withValue("trees", Value$.MODULE$.stringList(Predef$.MODULE$.wrapRefArray((String[]) Predef$.MODULE$.refArrayOps(gBTClassificationModel.trees()).map(new GBTClassifierOpV20$$anon$1$$anonfun$1(this, bundleContext, IntRef.create(0)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class))))));
            }

            /**
             * Reads a GBT model back from the bundle attributes written by {@code store}.
             *
             * @throws IllegalArgumentException if {@code num_classes} is not 2, or if
             *     {@code num_features} does not fit in an {@code int}
             */
            /* JADX WARN: Can't rename method to resolve collision */
            @Override // ml.combust.bundle.op.OpModel
            public GBTClassificationModel load(Model model, BundleContext<SparkBundleContext> bundleContext) {
                long numClasses = model.value("num_classes").getLong();
                // Only binary GBT classification is supported; store() always writes 2.
                if (numClasses != 2) {
                    throw new IllegalArgumentException(
                            "MLeap only supports binary GBT classification, but num_classes was " + numClasses);
                }
                long numFeatures = model.value("num_features").getLong();
                // A bare (int) cast would silently wrap an out-of-range value into a bogus feature count.
                if (numFeatures != (int) numFeatures) {
                    throw new IllegalArgumentException("num_features does not fit in an int: " + numFeatures);
                }
                DecisionTreeRegressionModel[] trees = (DecisionTreeRegressionModel[]) ((TraversableOnce) model.value("trees").getStringList().map(new GBTClassifierOpV20$$anon$1$$anonfun$2(this, bundleContext), Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(DecisionTreeRegressionModel.class));
                double[] treeWeights = (double[]) model.value("tree_weights").getDoubleList().toArray(ClassTag$.MODULE$.Double());
                // Empty uid placeholder. NOTE(review): presumably SimpleSparkOp then passes this
                // model to sparkLoad, which assigns the real uid; confirm.
                return new GBTClassificationModel("", trees, treeWeights, (int) numFeatures);
            }
        };
    }
}
