package ml.combust.mleap.bundle.ops.classification;

import ml.combust.bundle.BundleContext;
import ml.combust.bundle.dsl.Bundle$BuiltinOps$classification$;
import ml.combust.bundle.dsl.HasAttributes;
import ml.combust.bundle.dsl.Model;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.op.OpModel;
import ml.combust.mleap.bundle.ops.MleapOp;
import ml.combust.mleap.core.classification.AbstractLogisticRegressionModel;
import ml.combust.mleap.core.classification.BinaryLogisticRegressionModel;
import ml.combust.mleap.core.classification.LogisticRegressionModel;
import ml.combust.mleap.core.classification.ProbabilisticLogisticsRegressionModel;
import ml.combust.mleap.runtime.MleapContext;
import ml.combust.mleap.runtime.transformer.classification.LogisticRegression;
import ml.combust.mleap.tensor.DenseTensor;
import ml.combust.mleap.tensor.Tensor;
import org.apache.spark.ml.linalg.Matrices$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vectors$;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: LogisticRegressionOp.scala */
@ScalaSignature(bytes = "\u0006\u0001y2A!\u0001\u0002\u0001\u001f\t!Bj\\4jgRL7MU3he\u0016\u001c8/[8o\u001fBT!a\u0001\u0003\u0002\u001d\rd\u0017m]:jM&\u001c\u0017\r^5p]*\u0011QAB\u0001\u0004_B\u001c(BA\u0004\t\u0003\u0019\u0011WO\u001c3mK*\u0011\u0011BC\u0001\u0006[2,\u0017\r\u001d\u0006\u0003\u00171\tqaY8nEV\u001cHOC\u0001\u000e\u0003\tiGn\u0001\u0001\u0014\u0005\u0001\u0001\u0002\u0003B\t\u0013)ui\u0011\u0001B\u0005\u0003'\u0011\u0011q!\u00147fCB|\u0005\u000f\u0005\u0002\u001675\taC\u0003\u0002\u0004/)\u0011\u0001$G\u0001\fiJ\fgn\u001d4pe6,'O\u0003\u0002\u001b\u0011\u00059!/\u001e8uS6,\u0017B\u0001\u000f\u0017\u0005IaunZ5ti&\u001c'+Z4sKN\u001c\u0018n\u001c8\u0011\u0005y\u0011S\"A\u0010\u000b\u0005\r\u0001#BA\u0011\t\u0003\u0011\u0019wN]3\n\u0005\rz\"a\u0006'pO&\u001cH/[2SK\u001e\u0014Xm]:j_:lu\u000eZ3m\u0011\u0015)\u0003\u0001\"\u0001'\u0003\u0019a\u0014N\\5u}Q\tq\u0005\u0005\u0002)\u00015\t!\u0001C\u0004+\u0001\t\u0007I\u0011I\u0016\u0002\u000b5{G-\u001a7\u0016\u00031\u0002B!L\u00194;5\taF\u0003\u00020a\u0005\u0011q\u000e\u001d\u0006\u0003\u000f)I!A\r\u0018\u0003\u000f=\u0003Xj\u001c3fYB\u0011A'N\u0007\u00023%\u0011a'\u0007\u0002\r\u001b2,\u0017\r]\"p]R,\u0007\u0010\u001e\u0005\u0007q\u0001\u0001\u000b\u0011\u0002\u0017\u0002\r5{G-\u001a7!\u0011\u0015Q\u0004\u0001\"\u0011<\u0003\u0015iw\u000eZ3m)\tiB\bC\u0003>s\u0001\u0007A#\u0001\u0003o_\u0012,\u0007")
/* loaded from: input_file:ml/combust/mleap/bundle/ops/classification/LogisticRegressionOp.class */
/*
 * Bundle serialization op for the MLeap LogisticRegression transformer.
 *
 * Converts the transformer's core LogisticRegressionModel to and from bundle
 * attributes, using the builtin "logistic_regression" op name. Two attribute
 * layouts are used:
 *   - binary:      num_classes, coefficients, intercept, threshold
 *   - multinomial: num_classes, coefficient_matrix, intercept_vector, thresholds
 *
 * NOTE(review): this is decompiled output of LogisticRegressionOp.scala. The
 * LogisticRegressionOp$$anon$1$$anonfun$* classes are compiler-generated
 * closures whose bodies are not visible here. Make changes in the Scala
 * source, not in this file.
 */
public class LogisticRegressionOp extends MleapOp<LogisticRegression, LogisticRegressionModel> {
    // Serializer for the core model; assigned once in the constructor.
    private final OpModel<MleapContext, LogisticRegressionModel> Model;

    /** Returns the OpModel that reads and writes LogisticRegressionModel bundle attributes. */
    @Override // ml.combust.bundle.op.OpNode
    public OpModel<MleapContext, LogisticRegressionModel> Model() {
        return this.Model;
    }

    /**
     * Extracts the core model from a LogisticRegression transformer node.
     *
     * @param logisticRegression the transformer being serialized
     * @return the transformer's underlying LogisticRegressionModel
     */
    @Override // ml.combust.bundle.op.OpNode
    public LogisticRegressionModel model(LogisticRegression logisticRegression) {
        return logisticRegression.model();
    }

    /**
     * Registers the LogisticRegression transformer class (via its ClassTag) with the
     * parent MleapOp and builds the model serializer.
     */
    public LogisticRegressionOp() {
        super(ClassTag$.MODULE$.apply(LogisticRegression.class));
        this.Model = new OpModel<MleapContext, LogisticRegressionModel>(this) { // from class: ml.combust.mleap.bundle.ops.classification.LogisticRegressionOp$$anon$1
            // Runtime class of the model type handled by this serializer.
            private final Class<LogisticRegressionModel> klazz = LogisticRegressionModel.class;

            @Override // ml.combust.bundle.op.OpModel
            public Class<LogisticRegressionModel> klazz() {
                return this.klazz;
            }

            /** Bundle op name: the builtin classification "logistic_regression" name. */
            @Override // ml.combust.bundle.op.OpModel
            public String opName() {
                return Bundle$BuiltinOps$classification$.MODULE$.logistic_regression();
            }

            /**
             * Writes the model's parameters into the bundle Model as attributes.
             *
             * Always writes "num_classes". A non-multinomial model then gets
             * "coefficients", "intercept" and "threshold" from its binary model.
             * A multinomial model gets "coefficient_matrix", "intercept_vector"
             * and "thresholds".
             *
             * @param model the bundle model to add attributes to
             * @param logisticRegressionModel the model being serialized
             * @param bundleContext serialization context (not read here)
             * @return the bundle model with the attributes added
             */
            @Override // ml.combust.bundle.op.OpModel
            public Model store(Model model, LogisticRegressionModel logisticRegressionModel, BundleContext<MleapContext> bundleContext) {
                Model model2 = (Model) model.withValue("num_classes", Value$.MODULE$.m1848long(logisticRegressionModel.numClasses()));
                // NOTE(review): the layout here depends on isMultinomial(), but load() picks
                // its layout from num_classes > 2. If a multinomial model can have
                // numClasses <= 2, load() would look for "coefficients" and never find it.
                // Confirm the two conditions always agree.
                if (!logisticRegressionModel.isMultinomial()) {
                    return (Model) ((HasAttributes) ((HasAttributes) model2.withValue("coefficients", Value$.MODULE$.vector(logisticRegressionModel.binaryModel().coefficients().toArray(), ClassTag$.MODULE$.Double()))).withValue("intercept", Value$.MODULE$.m1850double(logisticRegressionModel.binaryModel().intercept()))).withValue("threshold", Value$.MODULE$.m1850double(logisticRegressionModel.binaryModel().threshold()));
                }
                ProbabilisticLogisticsRegressionModel multinomialModel = logisticRegressionModel.multinomialModel();
                Matrix coefficientMatrix = multinomialModel.coefficientMatrix();
                // The matrix is stored as a DenseTensor of its flat toArray() values with
                // dimensions [numRows, numCols]. load() rebuilds it with Matrices.dense from
                // the same flat array; presumably both use Spark's column-major layout (confirm).
                // "thresholds" passes through two compiler-generated mapping closures whose
                // bodies are not visible here.
                return (Model) ((HasAttributes) ((HasAttributes) model2.withValue("coefficient_matrix", Value$.MODULE$.tensor(new DenseTensor(coefficientMatrix.toArray(), (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{coefficientMatrix.numRows(), coefficientMatrix.numCols()})), ClassTag$.MODULE$.Double())))).withValue("intercept_vector", Value$.MODULE$.vector(multinomialModel.interceptVector().toArray(), ClassTag$.MODULE$.Double()))).withValue("thresholds", multinomialModel.thresholds().map(new LogisticRegressionOp$$anon$1$$anonfun$store$1(this)).map(new LogisticRegressionOp$$anon$1$$anonfun$store$2(this)));
            }

            /**
             * Rebuilds a LogisticRegressionModel from bundle attributes.
             *
             * When "num_classes" is greater than 2, a ProbabilisticLogisticsRegressionModel
             * is built from "coefficient_matrix", "intercept_vector" and the optional
             * "thresholds". Otherwise a BinaryLogisticRegressionModel is built from
             * "coefficients", "intercept" and the optional "threshold". The default used
             * when "threshold" is missing comes from a compiler-generated closure that is
             * not visible here.
             *
             * @param model the bundle model whose attributes are read
             * @param bundleContext deserialization context (not read here)
             * @return a LogisticRegressionModel wrapping the reconstructed implementation
             */
            /* JADX WARN: Can't rename method to resolve collision */
            @Override // ml.combust.bundle.op.OpModel
            public LogisticRegressionModel load(Model model, BundleContext<MleapContext> bundleContext) {
                // Despite the name (a decompiler artifact), this holds either the binary or
                // the multinomial implementation.
                AbstractLogisticRegressionModel binaryLogisticRegressionModel;
                if (model.value("num_classes").getLong() > 2) {
                    Tensor tensor = model.value("coefficient_matrix").getTensor();
                    // Matrix rows = first tensor dimension, cols = second; values are the flat tensor data.
                    binaryLogisticRegressionModel = new ProbabilisticLogisticsRegressionModel(Matrices$.MODULE$.dense(BoxesRunTime.unboxToInt(tensor.dimensions().mo3163head()), BoxesRunTime.unboxToInt(tensor.dimensions().mo3160apply(1)), (double[]) tensor.toArray()), Vectors$.MODULE$.dense((double[]) model.value("intercept_vector").getTensor().toArray()), model.getValue("thresholds").map(new LogisticRegressionOp$$anon$1$$anonfun$2(this)));
                } else {
                    binaryLogisticRegressionModel = new BinaryLogisticRegressionModel(Vectors$.MODULE$.dense((double[]) model.value("coefficients").getTensor().toArray()), model.value("intercept").getDouble(), BoxesRunTime.unboxToDouble(model.getValue("threshold").map(new LogisticRegressionOp$$anon$1$$anonfun$3(this)).getOrElse(new LogisticRegressionOp$$anon$1$$anonfun$1(this))));
                }
                return new LogisticRegressionModel(binaryLogisticRegressionModel);
            }
        };
    }
}
