diff --git a/README.md b/README.md
index eaf65d06256a534d682377be1b9ad009340aa424..9f4af39fc2c12f80d6e8524040469386ac4f03df 100644
--- a/README.md
+++ b/README.md
@@ -5,9 +5,9 @@
Introduction
============
-The machine learning algorithm library running on Kunpeng processors is an acceleration library that provides a rich set of high-level tools for machine learning algorithms. It is based on the original APIs of Apache [Spark 3.1.1](https://github.com/apache/spark/tree/v3.1.1). The acceleration library for greatly improves the computing power in big data scenarios.
+The machine learning algorithm library running on Kunpeng processors is an acceleration library that provides a rich set of high-level tools for machine learning algorithms. It is based on the original APIs of Apache [Spark 3.3.1](https://github.com/apache/spark/tree/v3.3.1). The acceleration library for greatly improves the computing power in big data scenarios.
-The library provides 10 machine learning algorithms: latent dirichlet allocation (LDA), prefix-projected pattern prowth (Prefix-Span), alternating least squares (ALS), K-nearest neighbors (KNN), Density-based spatial clustering of applicaitons with noise (DBSCAN), random forest classifier (RFC), gradient boosting decision tree (GBDT), decision tree (DT), decision tree bucket(DTB) and Word2Vec. You can find the latest documentation on the project web page. This README file contains only basic setup instructions.
+The library provides 4 machine learning algorithms: Density-based spatial clustering of applicaitons with noise (DBSCAN), Support Vector Machines)SVM), decision tree bucket(DTB) and Word2Vec. You can find the latest documentation on the project web page. This README file contains only basic setup instructions.
You can find the latest documentation, including a programming guide, on the project web page. This README file only contains basic setup instructions.
@@ -21,9 +21,9 @@ Building And Packageing
mvn clean package
-(2) Obtain "boostkit-ml-core_2.12-2.2.0-spark3.1.1.jar" under the "Spark-ml-algo-lib/ml-core/target" directory.
+(2) Obtain "boostkit-ml-core_2.12-3.0.0-spark3.3.1.jar" under the "Spark-ml-algo-lib/ml-core/target" directory.
- Obtain "boostkit-ml-acc_2.12-2.2.0-spark3.1.1.jar" under the "Spark-ml-algo-lib/ml-accelerator/target" directory.
+ Obtain "boostkit-ml-acc_2.12-3.0.0-spark3.3.1.jar" under the "Spark-ml-algo-lib/ml-accelerator/target" directory.
Contribution Guidelines
diff --git a/ml-accelerator/pom.xml b/ml-accelerator/pom.xml
index 800e7e164ec5e345154ec169fd8aafbcb1de277d..5c7222659621e70d8be74f4bdf08171c26d8464b 100644
--- a/ml-accelerator/pom.xml
+++ b/ml-accelerator/pom.xml
@@ -2,12 +2,12 @@
org.apache.spark
boostkit-ml
- 2.2.0
+ 3.0.0
4.0.0
boostkit-ml-acc_2.12
- 2.2.0
+ 3.0.0
${project.artifactId}
Spark ml algo accelerator
@@ -17,14 +17,20 @@
boostkit-ml-core_2.12
${project.version}
${spark.version}
-
+ provided
org.apache.spark
boostkit-ml-kernel-client_2.12
${project.version}
${spark.version}
- compile
+ provided
+
+
+ org.jpmml
+ pmml-model
+ 1.4.8
+ provided
@@ -44,6 +50,9 @@
+ -unchecked
+ -deprecation
+ -feature
-dependencyfile
${project.build.directory}/.scala_dependencies
diff --git a/ml-accelerator/src/main/scala/com/intel/ssg/bdt/nlp/CRF.scala b/ml-accelerator/src/main/scala/com/intel/ssg/bdt/nlp/CRF.scala
new file mode 100644
index 0000000000000000000000000000000000000000..48c92f0e2eff6b0b3a05b1c1c20ace9183f59c87
--- /dev/null
+++ b/ml-accelerator/src/main/scala/com/intel/ssg/bdt/nlp/CRF.scala
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// scalastyle:off
+package com.intel.ssg.bdt.nlp
+
+import breeze.linalg.{DenseVector => BDV}
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+
+trait Regularization
+
+case object L1 extends Regularization
+
+case object L2 extends Regularization
+
+/**
+ * CRF with support for multiple parallel runs
+ * L2 regParam = 1/(2.0 * sigma**2)
+ */
+class CRF private (
+ private var freq: Int,
+ private var regParam: Double,
+ private var maxIterations: Int,
+ private var tolerance: Double,
+ private var regularization: Regularization) extends Serializable with Logging {
+ // boost-kit parameter
+ private var compLevel: Int = 0
+ private var nThread: Int = 1
+ private var globalStageIterFraction: Double = 1.0
+ private var commFreeSplit: Int = 0
+ private var commFreeToleranceFactor: Double = 5.0
+ private var calcAccuracy: Boolean = false
+
+ def this() = this(
+ freq = 1,
+ regParam = 0.5,
+ maxIterations = 1000,
+ tolerance = 1E-3,
+ regularization = L2)
+
+ def setRegParam(regParam: Double): this.type = {
+ this.regParam = regParam
+ this
+ }
+
+ // set features that frequency greater than given value
+ def setFreq(freq: Int): this.type = {
+ this.freq = freq
+ this
+ }
+
+ def setMaxIterations(maxIterations: Int): this.type = {
+ this.maxIterations = maxIterations
+ this
+ }
+
+ def setTolerance(tol: Double): this.type = {
+ this.tolerance = tol
+ this
+ }
+
+ def setRegularization(reg: Regularization): this.type = {
+ this.regularization = reg
+ this
+ }
+
+ // set if need to calculate model's accuracy
+ // this requires testArrayWithLabel and testArrayWithoutLabel are given
+ def setCalcAccuracy(ca: Boolean): this.type = {
+ this.calcAccuracy = ca
+ this
+ }
+
+ // set data compression level
+ def setCompLevel(compLevel: Int): this.type = {
+ require(compLevel >= 0 && compLevel <= 3,
+ s"compLevel must be [0, 1, 2, 3] but got $compLevel")
+ this.compLevel = compLevel
+ this
+ }
+
+ // set global Stage Iteration Fraction
+ def setGlobalStageIterFraction(globalStageIterFraction: Double): this.type = {
+ require(globalStageIterFraction >= 0.0 && globalStageIterFraction <= 1.0,
+ s"globalStageIterFraction must be [0.0, 1.0] but got $globalStageIterFraction")
+ this.globalStageIterFraction = globalStageIterFraction
+ this
+ }
+
+ // set number of partitions for local stage
+ def setCommFreeSplit(commFreeSplit: Int): this.type = {
+ require(commFreeSplit >= 0,
+ s"commFreeSplit must be greater or equal than 0, but got $commFreeSplit")
+ this.commFreeSplit = commFreeSplit
+ this
+ }
+
+ // set Tolerance Factor for local stage
+ def setCommFreeToleranceFactor(commFreeToleranceFactor: Double): this.type = {
+ require(commFreeToleranceFactor >= 1.0 && commFreeToleranceFactor <= 10.0,
+ s"commFreeToleranceFactor must be in range of [1.0, 10.0], " +
+ s"but got $commFreeToleranceFactor")
+ this.commFreeToleranceFactor = commFreeToleranceFactor
+ this
+ }
+
+ // set number of thread
+ def setNumThread(numThread: Int): this.type = {
+ require(numThread >= 1, s"numThread must be greater than 0 but got $numThread")
+ this.nThread = numThread
+ this
+ }
+
+ /**
+ * train CRF model
+ *
+ * @param template the template to train the model
+ * @param trains the source for the training
+ * @param testArrayWithLabel test dataset with label
+ * @param testArrayWithoutLabel test dataset without label
+ * @return the model of the source
+ */
+ def runCRF(
+ template: Array[String],
+ trains: RDD[Sequence],
+ testArrayWithLabel: Array[Sequence] = Array[Sequence](),
+ testArrayWithoutLabel: Array[Sequence] = Array[Sequence]()): CRFModel = {
+ val featureIdx = new FeatureIndex()
+ featureIdx.openTemplate(template)
+ featureIdx.openTagSetDist(trains)
+
+ val bcFeatureIdxI: Broadcast[FeatureIndex] = trains.context.broadcast(featureIdx)
+ val taggers = trains.map(train => {
+ val tagger: Tagger = new Tagger(bcFeatureIdxI.value.labels.size, LearnMode)
+ tagger.read(train, bcFeatureIdxI.value)
+ tagger
+ })
+
+ featureIdx.buildDictionaryDist(taggers, bcFeatureIdxI, freq)
+
+ val bcFeatureIdxII = trains.context.broadcast(featureIdx)
+ val taggerList: RDD[Tagger] = taggers.map(bcFeatureIdxII.value.buildFeatures).cache()
+
+ val model = runAlgorithm(taggerList, featureIdx, testArrayWithLabel, testArrayWithoutLabel)
+ taggerList.unpersist()
+
+ model
+ }
+
+ /**
+ *
+ * @param taggers the tagger in the template
+ * @param featureIdx the index of the feature
+ */
+ private def runAlgorithm(
+ taggers: RDD[Tagger],
+ featureIdx: FeatureIndex,
+ testArrayWithLabel: Array[Sequence] = Array[Sequence](),
+ testArrayWithoutLabel: Array[Sequence] = Array[Sequence]()): CRFModel = {
+
+ logInfo("Starting CRF Iterations ( sentences: %d, features: %d, labels: %d )"
+ .format(taggers.count(), featureIdx.maxID, featureIdx.labels.length))
+
+ var updater: UpdaterCRF = null
+ regularization match {
+ case L1 =>
+ updater = new L1Updater
+ case L2 =>
+ updater = new L2Updater
+ case _ =>
+ throw new Exception("only support L1-CRF and L2-CRF now")
+ }
+
+ if (compLevel == 0 && nThread == 1) {
+ featureIdx.alpha = new CRFWithLBFGS(new CRFGradient, updater)
+ .setRegParam(regParam)
+ .setConvergenceTol(tolerance)
+ .setNumIterations(maxIterations)
+ .optimizer(taggers, featureIdx.initAlpha())
+ } else if (globalStageIterFraction != 1.0) {
+ val CRFObj = new CRFWithLBFGS(new CRFGradient, updater, compLevel, nThread)
+ .setRegParam(regParam)
+ .setConvergenceTol(tolerance * commFreeToleranceFactor)
+ .setNumIterations((maxIterations * (1.0 - globalStageIterFraction)).toInt)
+
+ featureIdx.alpha = runTwoStageCRF(CRFObj, taggers, featureIdx)
+ } else {
+ featureIdx.alpha = new CRFWithLBFGS(new CRFGradient, updater, compLevel, nThread)
+ .setRegParam(regParam)
+ .setConvergenceTol(tolerance)
+ .setNumIterations(maxIterations)
+ .optimizerX(taggers, featureIdx.initAlpha())
+ }
+
+ // calculate the accuracy faster
+ if (calcAccuracy && testArrayWithLabel.length == testArrayWithoutLabel.length) {
+ if (testArrayWithLabel.length != 0) {
+ Accuracy.calc(featureIdx, testArrayWithLabel, testArrayWithoutLabel)
+ } else {
+ logInfo(s"test dataset not given.")
+ }
+ }
+
+ featureIdx.saveModel
+ }
+
+ private def runTwoStageCRF(
+ CRFWithLBFGSObj: CRFWithLBFGS,
+ taggers: RDD[Tagger],
+ featureIdx: FeatureIndex): BDV[Double] = {
+ val numParts = taggers.getNumPartitions
+ val taggersStage1 = if (commFreeSplit != 0) {
+ taggers.repartition(commFreeSplit).cache()
+ } else {
+ taggers.cache()
+ }
+
+ // stage1: local training
+ val weightsIdsRDD: RDD[BDV[Double]] = taggersStage1.mapPartitions { tgr =>
+ Iterator(CRFWithLBFGSObj.optimizerLocal(tgr, featureIdx.initAlpha()))
+ }
+
+ // sum up local weights and then average
+ val weightsAveraged = weightsIdsRDD.reduce((weight1, weight2) => weight1 +:+ weight2)
+ featureIdx.alpha = weightsAveraged *:* (1 / taggersStage1.getNumPartitions.toDouble)
+
+ taggersStage1.unpersist()
+
+ // reset the tolerance and number of iterations for global stage
+ val globalIter = (maxIterations * globalStageIterFraction).toInt
+ CRFWithLBFGSObj.setConvergenceTol(tolerance).setNumIterations(globalIter)
+
+ val taggersStage2 = taggers.repartition(numParts).cache()
+
+ // stage2: global training
+ featureIdx.alpha = CRFWithLBFGSObj.optimizerX(taggersStage2, featureIdx.alpha)
+ taggersStage2.unpersist()
+
+ featureIdx.alpha
+ }
+}
+
+
+object Accuracy extends Logging {
+ def calc(
+ featureIdx: FeatureIndex,
+ testArrayWithLabel: Array[Sequence],
+ testArrayWithoutLabel: Array[Sequence]): Double = {
+ val results = testArrayWithoutLabel.map(testCRF(_, featureIdx))
+ var score = 0
+ var i = 0
+ for (r <- results) {
+ score += r.compare(testArrayWithLabel(i))
+ i += 1
+ }
+ val total = testArrayWithoutLabel.map(_.toArray.length).sum
+
+ logInfo(f"==== Prediction Accuracy: $score / $total = ${score / total.toDouble} ====")
+
+ score / total.toDouble
+ }
+
+ private def testCRF(test: Sequence, featureIdx: FeatureIndex): Sequence = {
+ val tagger = new Tagger(featureIdx.labels.size, TestMode)
+ tagger.read(test, featureIdx)
+ featureIdx.buildFeatures(tagger)
+ tagger.parse(featureIdx.alpha, None)
+
+ Sequence(test.toArray.map { x =>
+ Token.put(featureIdx.labels(tagger.result(test.toArray.indexOf(x))), x.tags)
+ })
+ }
+}
diff --git a/ml-accelerator/src/main/scala/com/intel/ssg/bdt/nlp/CRFWithLBFGS.scala b/ml-accelerator/src/main/scala/com/intel/ssg/bdt/nlp/CRFWithLBFGS.scala
new file mode 100644
index 0000000000000000000000000000000000000000..a2651b04df6b5ba194e8cd82b38a0e3a75bbf7ea
--- /dev/null
+++ b/ml-accelerator/src/main/scala/com/intel/ssg/bdt/nlp/CRFWithLBFGS.scala
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// scalastyle:off
+package com.intel.ssg.bdt.nlp
+
+import scala.collection.mutable
+import breeze.linalg.{DenseVector => BDV, sum => Bsum}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.linalg.{Vector => SparkVector}
+import org.apache.spark.mllib.optimization._
+import org.apache.spark.nlp.{CRFGradientX, CostFunX}
+import org.apache.spark.rdd.RDD
+
+import scala.collection.mutable.ArrayBuffer
+import scala.language.existentials
+
+class CRFWithLBFGS(
+ private var gradient: CRFGradient,
+ private var updater: Updater,
+ private var compLevel: Int = 0,
+ private var nThread: Int = 1)
+ extends LBFGS(gradient: Gradient, updater: Updater) {
+
+ private val numCorrections = 5
+ private var maxNumIterations = 100
+ private var convergenceTol = 1E-4
+ private var regParam = 0.5
+
+ /**
+ * Set the regularization parameter. Default 0.5.
+ */
+ override def setRegParam(regParam: Double): this.type = {
+ this.regParam = regParam
+ this
+ }
+
+ /**
+ * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
+ * Smaller value will lead to higher accuracy with the cost of more iterations.
+ * This value must be nonnegative. Lower convergence values are less tolerant
+ * and therefore generally cause more iterations to be run.
+ */
+ override def setConvergenceTol(tolerance: Double): this.type = {
+ this.convergenceTol = tolerance
+ this
+ }
+
+ /**
+ * Set the maximal number of iterations for L-BFGS. Default 100.
+ */
+ override def setNumIterations(iters: Int): this.type = {
+ this.maxNumIterations = iters
+ this
+ }
+
+ def optimizer(data: RDD[Tagger], initialWeights: BDV[Double]): BDV[Double] = {
+ CRFWithLBFGS.runLBFGS(data,
+ gradient,
+ updater,
+ numCorrections,
+ convergenceTol,
+ maxNumIterations,
+ regParam,
+ initialWeights)
+ }
+
+ def optimizerLocal(data: Iterator[Tagger], initialWeights: BDV[Double]): BDV[Double] = {
+ CRFWithLBFGS.runLBFGSLocal(data,
+ gradient,
+ updater,
+ numCorrections,
+ convergenceTol,
+ maxNumIterations,
+ regParam,
+ initialWeights)
+ }
+
+ def optimizerX(data: RDD[Tagger], initialWeights: BDV[Double]): BDV[Double] = {
+ val gradientX = new CRFGradientX
+ val processedData = CRFGradientX.dataProcess(data, nThread)
+ CRFWithLBFGS.runLBFGSX(processedData,
+ gradientX,
+ updater,
+ numCorrections,
+ convergenceTol,
+ maxNumIterations,
+ regParam,
+ initialWeights,
+ compLevel,
+ nThread)
+ }
+}
+
+object CRFWithLBFGS extends Logging {
+ def runLBFGS(
+ data: RDD[Tagger],
+ gradient: CRFGradient,
+ updater: Updater,
+ numCorrections: Int,
+ convergenceTol: Double,
+ maxNumIterations: Int,
+ regParam: Double,
+ initialWeights: BDV[Double]): BDV[Double] = {
+
+ val costFun = new CostFun(data, gradient, updater, regParam)
+
+ var lbfgs: BreezeLBFGS[BDV[Double]] = null
+
+ updater match {
+ case updater: L1Updater =>
+ lbfgs = new BreezeOWLQN[Int, BDV[Double]](
+ maxNumIterations,
+ numCorrections,
+ regParam,
+ convergenceTol)
+ case updater: L2Updater =>
+ lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
+ }
+
+ val states = lbfgs.iterations(new CachedDiffFunction[BDV[Double]](costFun), initialWeights)
+
+ val lossHistory = mutable.ArrayBuilder.make[Double]
+ var state = states.next()
+ while (states.hasNext) {
+ lossHistory += state.value
+ state = states.next()
+ }
+
+ logInfo("LBFGS.runLBFGS finished after %s iterations. last 10 losses: %s".format(
+ state.iter, lossHistory.result().takeRight(10).mkString(" -> ")))
+ state.x
+ }
+
+ private def runLBFGSLocal(
+ data: Iterator[Tagger],
+ gradient: CRFGradient,
+ updater: Updater,
+ numCorrections: Int,
+ convergenceTol: Double,
+ maxNumIterations: Int,
+ regParam: Double,
+ initialWeights: BDV[Double]): BDV[Double] = {
+
+ val costFun = new CostFunLocal(data, gradient, updater, regParam)
+
+ var lbfgs: BreezeLBFGS[BDV[Double]] = null
+
+ updater match {
+ case updater: L1Updater =>
+ lbfgs = new BreezeOWLQN[Int, BDV[Double]](
+ maxNumIterations,
+ numCorrections,
+ regParam,
+ convergenceTol)
+ case updater: L2Updater =>
+ lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
+ }
+
+ val states = lbfgs.iterations(new CachedDiffFunction[BDV[Double]](costFun), initialWeights)
+
+ val lossHistory = mutable.ArrayBuilder.make[Double]
+ var state = states.next()
+ while (states.hasNext) {
+ lossHistory += state.value
+ state = states.next()
+ }
+
+ lossHistory += state.value
+
+ logInfo("LBFGS.runLBFGS finished after %s iterations. last 10 losses: %s".format(
+ state.iter, lossHistory.result().takeRight(10).mkString(" -> ")))
+ state.x
+ }
+
+ private def runLBFGSX(
+ data: RDD[Array[ArrayBuffer[Tagger]]],
+ gradient: CRFGradientX,
+ updater: Updater,
+ numCorrections: Int,
+ convergenceTol: Double,
+ maxNumIterations: Int,
+ regParam: Double,
+ initialWeights: BDV[Double],
+ compLevel: Int,
+ nThread: Int): BDV[Double] = {
+
+ val costFunX = new CostFunX(data, gradient, updater, regParam, compLevel, nThread)
+ if (compLevel != 0) {
+ costFunX.setDriverCoreFromSparkConf(data.context)
+ }
+
+ var lbfgs: BreezeLBFGS[BDV[Double]] = null
+
+ updater match {
+ case updater: L1Updater =>
+ lbfgs = new BreezeOWLQN[Int, BDV[Double]](maxNumIterations,
+ numCorrections,
+ regParam,
+ convergenceTol)
+ case updater: L2Updater =>
+ lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
+ }
+
+ val states = lbfgs.iterations(new CachedDiffFunction[BDV[Double]](costFunX), initialWeights)
+
+ val lossHistory = mutable.ArrayBuilder.make[Double]
+ var state = states.next()
+ while (states.hasNext) {
+ lossHistory += state.value
+ state = states.next()
+ }
+
+ logInfo("LBFGS.runLBFGS finished after %s iterations. last 10 losses: %s".format(
+ state.iter, lossHistory.result().takeRight(10).mkString(" -> ")))
+ state.x
+ }
+}
+
+class CRFGradient extends Gradient {
+ def compute(
+ data: SparkVector,
+ label: Double,
+ weights: SparkVector,
+ cumGradient: SparkVector): Double = {
+ throw new Exception("The original compute() method is not supported")
+ }
+
+ def computeCRF(sentences: Iterator[Tagger], weights: BDV[Double]): (BDV[Double], Double) = {
+
+ val expected = BDV.zeros[Double](weights.length)
+ var obj: Double = 0.0
+ while (sentences.hasNext)
+ obj += sentences.next().gradient(expected, weights)
+
+ (expected, obj)
+ }
+}
+
+class L2Updater extends UpdaterCRF {
+ def computeCRF(
+ weightsOld: BDV[Double],
+ gradient: BDV[Double],
+ regParam: Double): (BDV[Double], Double) = {
+ val loss = Bsum(weightsOld *:* weightsOld *:* regParam)
+ gradient :+= weightsOld *:* (regParam * 2.0)
+ (gradient, loss)
+ }
+}
+
+class L1Updater extends UpdaterCRF {
+ def computeCRF(
+ weightsOld: BDV[Double],
+ gradient: BDV[Double],
+ regParam: Double): (BDV[Double], Double) = {
+ (gradient, 0.0)
+ }
+}
+
+private class CostFun(
+ taggers: RDD[Tagger],
+ gradient: CRFGradient,
+ updater: Updater,
+ regParam: Double) extends DiffFunction[BDV[Double]] with Logging with Serializable {
+
+ override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+
+ val start = System.currentTimeMillis
+ val bcWeights = taggers.context.broadcast(weights)
+ lazy val treeDepth = math.ceil(math.log(taggers.partitions.length) / (math.log(2) * 2)).toInt.max(1)
+ val computeRes = taggers.mapPartitions(sentences =>
+ Iterator(gradient.computeCRF(sentences, bcWeights.value))
+ )
+
+ val (expected, obj) = computeRes.treeReduce((p1, p2) => (p1, p2) match {
+ case ((expected1, obj1), (expected2, obj2)) =>
+ (expected1 + expected2, obj1 + obj2)
+ }, treeDepth)
+
+ val (grad, loss) = updater.asInstanceOf[UpdaterCRF].computeCRF(weights, expected, regParam)
+
+ val end = System.currentTimeMillis
+ logInfo(s"Run Time for raw = %f[s]\n".format((end - start) / 1000.0))
+
+ (obj + loss, grad)
+ }
+}
+
+private class CostFunLocal(
+ var taggersOriginal: Iterator[Tagger],
+ gradient: CRFGradient,
+ updater: Updater,
+ regParam: Double) extends DiffFunction[BDV[Double]] with Logging with Serializable {
+
+ override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+ val start = System.currentTimeMillis
+ val (taggers, tmp) = taggersOriginal.duplicate
+ taggersOriginal = tmp
+
+ val (expected, obj) = gradient.computeCRF(taggers, weights)
+
+ var grad: BDV[Double] = BDV.zeros[Double](weights.length)
+ var loss: Double = 0.0
+
+ updater match {
+ case updater: UpdaterCRF =>
+ val (gradTmp, lossTmp) = updater.computeCRF(weights, expected, regParam)
+ grad = gradTmp
+ loss = lossTmp
+ }
+
+ val end = System.currentTimeMillis
+ logInfo(s"Run Time for local = %f[s]\n".format((end - start) / 1000.0))
+
+ (obj + loss, grad)
+ }
+}
\ No newline at end of file
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 9c22cb19d91fc15875d8cb1d460e3eee73dacc55..0bd00cb5fcad74e26d8102a68396be09e8080658 100644
--- a/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -288,7 +288,8 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
- sparkSession.createDataFrame(nodeData).write.parquet(dataPath)
+ val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
+ sparkSession.createDataFrame(nodeData).repartition(numDataParts).write.parquet(dataPath)
}
}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
new file mode 100644
index 0000000000000000000000000000000000000000..99f7323002d10af645739f5fe8cb15eacdc8c5b0
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -0,0 +1,412 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.StaticUtils
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.regression.{FactorizationMachines, FactorizationMachinesParams}
+import org.apache.spark.ml.regression.FactorizationMachines._
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.mllib.linalg.{Vector => OldVector}
+import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for FMClassifier.
+ */
+private[classification] trait FMClassifierParams extends ProbabilisticClassifierParams
+ with FactorizationMachinesParams {
+}
+
+/**
+ * Factorization Machines learning algorithm for classification.
+ * It supports normal gradient descent and AdamW solver.
+ *
+ * The implementation is based upon:
+ *
+ * S. Rendle. "Factorization machines" 2010.
+ *
+ * FM is able to estimate interactions even in problems with huge sparsity
+ * (like advertising and recommendation system).
+ * FM formula is:
+ *
+ * $$
+ * \begin{align}
+ * y = \sigma\left( w_0 + \sum\limits^n_{i-1} w_i x_i +
+ * \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j \right)
+ * \end{align}
+ * $$
+ *
+ * First two terms denote global bias and linear term (as same as linear regression),
+ * and last term denotes pairwise interactions term. v_i describes the i-th variable
+ * with k factors.
+ *
+ * FM classification model uses logistic loss which can be solved by gradient descent method, and
+ * regularization terms like L2 are usually added to the loss function to prevent overfitting.
+ *
+ * @note Multiclass labels are not currently supported.
+ */
+@Since("3.0.0")
+class FMClassifier @Since("3.0.0") (
+ @Since("3.0.0") override val uid: String)
+ extends ProbabilisticClassifier[Vector, FMClassifier, FMClassificationModel]
+ with FactorizationMachines with FMClassifierParams with DefaultParamsWritable with Logging {
+
+ @Since("3.0.0")
+ def this() = this(Identifiable.randomUID("fmc"))
+
+ /**
+ * Set the dimensionality of the factors.
+ * Default is 8.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setFactorSize(value: Int): this.type = set(factorSize, value)
+
+ /**
+ * Set whether to fit intercept term.
+ * Default is true.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+
+ /**
+ * Set whether to fit linear term.
+ * Default is true.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setFitLinear(value: Boolean): this.type = set(fitLinear, value)
+
+ /**
+ * Set the L2 regularization parameter.
+ * Default is 0.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /**
+ * Set the mini-batch fraction parameter.
+ * Default is 1.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setMiniBatchFraction(value: Double): this.type = set(miniBatchFraction, value)
+
+ /**
+ * Set the standard deviation of initial coefficients.
+ * Default is 0.01.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setInitStd(value: Double): this.type = set(initStd, value)
+
+ /**
+ * Set the maximum number of iterations.
+ * Default is 100.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /**
+ * Set the initial step size for the first step (like learning rate).
+ * Default is 1.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setStepSize(value: Double): this.type = set(stepSize, value)
+
+ /**
+ * Set the convergence tolerance of iterations.
+ * Default is 1E-6.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setTol(value: Double): this.type = set(tol, value)
+
+ /**
+ * Set the solver algorithm used for optimization.
+ * Supported options: "gd", "adamW".
+ * Default: "adamW"
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setSolver(value: String): this.type = set(solver, value)
+
+ /**
+ * Set the random seed for weight initialization.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ override protected def train(
+ dataset: Dataset[_]): FMClassificationModel = instrumented { instr =>
+ val numClasses = 2
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".train() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
+
+ instr.logPipelineStage(this)
+ instr.logDataset(dataset)
+ instr.logParams(this, factorSize, fitIntercept, fitLinear, regParam,
+ miniBatchFraction, initStd, maxIter, stepSize, tol, solver, thresholds)
+ instr.logNumClasses(numClasses)
+
+ val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+ instr.logNumFeatures(numFeatures)
+
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+ val labeledPoint = extractLabeledPoints(dataset, numClasses)
+ val data: RDD[(Double, OldVector)] =
+ labeledPoint.map(x => (x.label + StaticUtils.ZERO_DOUBLE, x.features))
+
+ if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val (coefficients, objectiveHistory) = trainImpl(data, numFeatures, LogisticLoss)
+
+ val (intercept, linear, factors) = splitCoefficients(
+ coefficients, numFeatures, $(factorSize), $(fitIntercept), $(fitLinear))
+
+ if (handlePersistence) data.unpersist()
+
+ createModel(dataset, intercept, linear, factors, objectiveHistory)
+ }
+
+ private def createModel(
+ dataset: Dataset[_],
+ intercept: Double,
+ linear: Vector,
+ factors: Matrix,
+ objectiveHistory: Array[Double]): FMClassificationModel = {
+ val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors))
+ val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+ val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
+ val summary = new FMClassificationTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ probabilityColName,
+ predictionColName,
+ $(labelCol),
+ weightColName,
+ objectiveHistory)
+ model.setSummary(Some(summary))
+ }
+
+ @Since("3.0.0")
+ override def copy(extra: ParamMap): FMClassifier = defaultCopy(extra)
+}
+
+@Since("3.0.0")
+object FMClassifier extends DefaultParamsReadable[FMClassifier] {
+
+ @Since("3.0.0")
+ override def load(path: String): FMClassifier = super.load(path)
+}
+
+/**
+ * Model produced by [[FMClassifier]]
+ */
+@Since("3.0.0")
+class FMClassificationModel private[classification] (
+ @Since("3.0.0") override val uid: String,
+ @Since("3.0.0") val intercept: Double,
+ @Since("3.0.0") val linear: Vector,
+ @Since("3.0.0") val factors: Matrix)
+ extends ProbabilisticClassificationModel[Vector, FMClassificationModel]
+ with FMClassifierParams with MLWritable
+ with HasTrainingSummary[FMClassificationTrainingSummary]{
+
+ @Since("3.0.0")
+ override val numClasses: Int = 2
+
+ @Since("3.0.0")
+ override val numFeatures: Int = linear.size
+
+ /**
+ * Gets summary of model on training set. An exception is thrown
+ * if `hasSummary` is false.
+ */
+ @Since("3.1.0")
+ override def summary: FMClassificationTrainingSummary = super.summary
+
+ /**
+ * Evaluates the model on a test dataset.
+ *
+ * @param dataset Test dataset to evaluate model on.
+ */
+ @Since("3.1.0")
+ def evaluate(dataset: Dataset[_]): FMClassificationSummary = {
+ val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+ // Handle possible missing or invalid probability or prediction columns
+ val (summaryModel, probability, predictionColName) = findSummaryModel()
+ new FMClassificationSummaryImpl(summaryModel.transform(dataset),
+ probability, predictionColName, $(labelCol), weightColName)
+ }
+
+ @Since("3.0.0")
+ override def predictRaw(features: Vector): Vector = {
+ val rawPrediction = getRawPrediction(features, intercept, linear, factors)
+ Vectors.dense(Array(-rawPrediction, rawPrediction))
+ }
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ dv.values(1) = 1.0 / (1.0 + math.exp(-dv.values(1)))
+ dv.values(0) = 1.0 - dv.values(1)
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in FMClassificationModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
+ }
+
+ @Since("3.0.0")
+ override def copy(extra: ParamMap): FMClassificationModel = {
+ copyValues(new FMClassificationModel(uid, intercept, linear, factors), extra)
+ }
+
+ @Since("3.0.0")
+ override def write: MLWriter =
+ new FMClassificationModel.FMClassificationModelWriter(this)
+
+ override def toString: String = {
+ s"FMClassificationModel: " +
+ s"uid=${super.toString}, numClasses=$numClasses, numFeatures=$numFeatures, " +
+ s"factorSize=${$(factorSize)}, fitLinear=${$(fitLinear)}, fitIntercept=${$(fitIntercept)}"
+ }
+}
+
+@Since("3.0.0")
+object FMClassificationModel extends MLReadable[FMClassificationModel] {
+
+ @Since("3.0.0")
+ override def read: MLReader[FMClassificationModel] = new FMClassificationModelReader
+
+ @Since("3.0.0")
+ override def load(path: String): FMClassificationModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[FMClassificationModel]] */
+ private[FMClassificationModel] class FMClassificationModelWriter(
+ instance: FMClassificationModel) extends MLWriter with Logging {
+
+ private case class Data(
+ intercept: Double,
+ linear: Vector,
+ factors: Matrix)
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.intercept, instance.linear, instance.factors)
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class FMClassificationModelReader extends MLReader[FMClassificationModel] {
+
+ private val className = classOf[FMClassificationModel].getName
+
+ override def load(path: String): FMClassificationModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val data = sparkSession.read.format("parquet").load(dataPath)
+
+ val Row(intercept: Double, linear: Vector, factors: Matrix) =
+ data.select("intercept", "linear", "factors").head()
+ val model = new FMClassificationModel(metadata.uid, intercept, linear, factors)
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+}
+
+/**
+ * Abstraction for FMClassifier results for a given model.
+ */
+sealed trait FMClassificationSummary extends BinaryClassificationSummary
+
+/**
+ * Abstraction for FMClassifier training results.
+ */
+sealed trait FMClassificationTrainingSummary extends FMClassificationSummary with TrainingSummary
+
+/**
+ * FMClassifier results for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param scoreCol field in "predictions" which gives the probability of each instance.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ */
+private class FMClassificationSummaryImpl(
+ @transient override val predictions: DataFrame,
+ override val scoreCol: String,
+ override val predictionCol: String,
+ override val labelCol: String,
+ override val weightCol: String)
+ extends FMClassificationSummary
+
+/**
+ * FMClassifier training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param scoreCol field in "predictions" which gives the probability of each instance.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param weightCol field in "predictions" which gives the weight of each instance.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class FMClassificationTrainingSummaryImpl(
+ predictions: DataFrame,
+ scoreCol: String,
+ predictionCol: String,
+ labelCol: String,
+ weightCol: String,
+ override val objectiveHistory: Array[Double])
+ extends FMClassificationSummaryImpl(
+ predictions, scoreCol, predictionCol, labelCol, weightCol)
+ with FMClassificationTrainingSummary
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 37d386b101451012b75d3f0ecabee00d9503e45c..86140028b53cc210a34bd8350f97833c7ccf13f5 100644
--- a/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.classification
-import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
-import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
@@ -382,7 +381,7 @@ class GBTClassificationModel private[ml](
/** Raw prediction for the positive class. */
private def margin(features: Vector): Double = {
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
- blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
+ BLAS.nativeBLAS.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
}
/** (private[ml]) Convert to a model in the old API */
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b2ee6a13c5ed64df5d2d6bb86dd8bddb291c7bb0
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import scala.collection.mutable
+
+import breeze.linalg.{DenseVector => BDV}
+import breeze.optimize.{CachedDiffFunction, OWLQNF}
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.StaticUtils
+import org.apache.spark.ml.feature._
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.optim.aggregator._
+import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.stat._
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ *
+ * Linear SVM Classifier
+ *
+ * This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
+ * Only supports L2 regularization currently.
+ *
+ * Since 3.1.0, it supports stacking instances into blocks and using GEMV for
+ * better performance.
+ * The block size will be 1.0 MB, if param maxBlockSizeInMB is set 0.0 by default.
+ *
+ */
+@Since("2.2.0")
+class LinearSVC @Since("2.2.0") (
+ @Since("2.2.0") override val uid: String)
+ extends Classifier[Vector, LinearSVC, LinearSVCModel]
+ with LinearSVCParams with DefaultParamsWritable {
+
+ @Since("2.2.0")
+ def this() = this(Identifiable.randomUID("linearsvc"))
+
+ /**
+ * Set the regularization parameter.
+ * Default is 0.0.
+ *
+ * @group setParam
+ */
+ @Since("2.2.0")
+ def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /**
+ * Set the maximum number of iterations.
+ * Default is 100.
+ *
+ * @group setParam
+ */
+ @Since("2.2.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /**
+ * Whether to fit an intercept term.
+ * Default is true.
+ *
+ * @group setParam
+ */
+ @Since("2.2.0")
+ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+
+ /**
+ * Set the convergence tolerance of iterations.
+ * Smaller values will lead to higher accuracy at the cost of more iterations.
+ * Default is 1E-6.
+ *
+ * @group setParam
+ */
+ @Since("2.2.0")
+ def setTol(value: Double): this.type = set(tol, value)
+
+ /**
+ * Whether to standardize the training features before fitting the model.
+ * Default is true.
+ *
+ * @group setParam
+ */
+ @Since("2.2.0")
+ def setStandardization(value: Boolean): this.type = set(standardization, value)
+
+ /**
+ * Set the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is not set, so all instances have weight one.
+ *
+ * @group setParam
+ */
+ @Since("2.2.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ /**
+ * Set threshold in binary classification.
+ *
+ * @group setParam
+ */
+ @Since("2.2.0")
+ def setThreshold(value: Double): this.type = set(threshold, value)
+
+ /**
+ * Suggested depth for treeAggregate (greater than or equal to 2).
+ * If the dimensions of features or the number of partitions are large,
+ * this param could be adjusted to a larger size.
+ * Default is 2.
+ *
+ * @group expertSetParam
+ */
+ @Since("2.2.0")
+ def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+
+ /**
+ * Sets the value of param [[maxBlockSizeInMB]].
+ * Default is 0.0, then 1.0 MB will be chosen.
+ *
+ * @group expertSetParam
+ */
+ @Since("3.1.0")
+ def setMaxBlockSizeInMB(value: Double): this.type = set(maxBlockSizeInMB, value)
+
+ @Since("2.2.0")
+ override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
+
+ override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr =>
+ instr.logPipelineStage(this)
+ instr.logDataset(dataset)
+ instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, rawPredictionCol,
+ regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth,
+ maxBlockSizeInMB)
+
+ if (dataset.storageLevel != StorageLevel.NONE) {
+ instr.logWarning(s"Input instances will be standardized, blockified to blocks, and " +
+ s"then cached during training. Be careful of double caching!")
+ }
+
+ val instances = extractInstances(dataset)
+ .setName("training instances")
+
+ val (summarizer, labelSummarizer) = Summarizer
+ .getClassificationSummarizers(instances, $(aggregationDepth), Seq("mean", "std", "count"))
+
+ val histogram = labelSummarizer.histogram
+ val numInvalid = labelSummarizer.countInvalid
+ val numFeatures = summarizer.mean.size
+
+ instr.logNumExamples(summarizer.count)
+ instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
+ instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
+ instr.logSumOfWeights(summarizer.weightSum)
+
+ var actualBlockSizeInMB = $(maxBlockSizeInMB)
+ if (actualBlockSizeInMB == 0) {
+ actualBlockSizeInMB = InstanceBlock.DefaultBlockSizeInMB
+ require(actualBlockSizeInMB > 0, "inferred actual BlockSizeInMB must > 0")
+ instr.logNamedValue("actualBlockSizeInMB", actualBlockSizeInMB.toString)
+ }
+
+ val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
+ case Some(n: Int) =>
+ require(n >= histogram.length, s"Specified number of classes $n was " +
+ s"less than the number of unique labels ${histogram.length}.")
+ n
+ case None => histogram.length
+ }
+ require(numClasses == 2, s"LinearSVC only supports binary classification." +
+ s" $numClasses classes detected in $labelCol")
+ instr.logNumClasses(numClasses)
+ instr.logNumFeatures(numFeatures)
+
+ if (numInvalid != 0) {
+ val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
+ s"Found $numInvalid invalid labels."
+ instr.logError(msg)
+ throw new SparkException(msg)
+ }
+
+ val featuresStd = summarizer.std.toArray
+ val featuresMean = summarizer.mean.toArray
+ val getFeaturesStd = (j: Int) => featuresStd(j)
+ val regularization = if ($(regParam) != 0.0) {
+ val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
+ Some(new L2Regularization($(regParam), shouldApply,
+ if ($(standardization)) None else Some(getFeaturesStd)))
+ } else None
+
+ def regParamL1Fun = (index: Int) => 0.0
+ val optimizer = new OWLQNF($(maxIter), 10, regParamL1Fun, $(tol))
+
+ /*
+ The coefficients are trained in the scaled space; we're converting them back to
+ the original space.
+ Note that the intercept in scaled space and original space is the same;
+ as a result, no scaling is needed.
+ */
+ val (rawCoefficients, objectiveHistory) =
+ trainImpl(instances, actualBlockSizeInMB, featuresStd, featuresMean,
+ regularization, optimizer)
+
+ if (rawCoefficients == null) {
+ val msg = s"${optimizer.getClass.getName} failed."
+ instr.logError(msg)
+ throw new SparkException(msg)
+ }
+
+ val coefficientArray = Array.tabulate(numFeatures) { i =>
+ if (featuresStd(i) != 0.0) rawCoefficients(i) / featuresStd(i) else 0.0
+ }
+ val intercept = if ($(fitIntercept)) rawCoefficients.last else 0.0
+ createModel(dataset, Vectors.dense(coefficientArray), intercept, objectiveHistory)
+ }
+
+ private def createModel(
+ dataset: Dataset[_],
+ coefficients: Vector,
+ intercept: Double,
+ objectiveHistory: Array[Double]): LinearSVCModel = {
+ val model = copyValues(new LinearSVCModel(uid, coefficients, intercept))
+ val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+ val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel()
+ val summary = new LinearSVCTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ rawPredictionColName,
+ predictionColName,
+ $(labelCol),
+ weightColName,
+ objectiveHistory)
+ model.setSummary(Some(summary))
+ }
+
+ private def trainImpl(
+ instances: RDD[Instance],
+ actualBlockSizeInMB: Double,
+ featuresStd: Array[Double],
+ featuresMean: Array[Double],
+ regularization: Option[L2Regularization],
+ optimizer: OWLQNF): (Array[Double], Array[Double]) = {
+ val numFeatures = featuresStd.length
+ val numFeaturesPlusIntercept = if ($(fitIntercept)) numFeatures + 1 else numFeatures
+
+ val inverseStd = featuresStd.map(std => if (std != 0) 1.0 / std else 0.0)
+ val scaledMean = Array.tabulate(numFeatures)(i => inverseStd(i) * featuresMean(i))
+ val bcInverseStd = instances.context.broadcast(inverseStd)
+ val bcScaledMean = instances.context.broadcast(scaledMean)
+
+ val standardized = instances.mapPartitions { iter =>
+ val func = StandardScalerModel.getTransformFunc(Array.empty, bcInverseStd.value, false, true)
+ iter.map { case Instance(label, weight, vec) =>
+ Instance(label, weight + StaticUtils.ZERO_DOUBLE, func(vec)) }
+ }
+
+ val maxMemUsage = (actualBlockSizeInMB * 1024L * 1024L).ceil.toLong
+ val blocks = InstanceBlock.blokifyWithMaxMemUsage(standardized, maxMemUsage)
+ .persist(StorageLevel.MEMORY_AND_DISK)
+ .setName(s"training blocks (blockSizeInMB=$actualBlockSizeInMB)")
+
+ val getAggregatorFunc = new HingeBlockAggregator(bcInverseStd, bcScaledMean,
+ $(fitIntercept))(_)
+ val costFun = new RDDLossFunction(blocks, getAggregatorFunc,
+ regularization, $(aggregationDepth))
+
+ val initialSolution = Array.ofDim[Double](numFeaturesPlusIntercept)
+ if ($(fitIntercept)) {
+ // orginal `initialSolution` is for problem:
+ // y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept)
+ // we should adjust it to the initial solution for problem:
+ // y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept)
+ // NOTE: this is NOOP before we finally support model initialization
+ val adapt = BLAS.javaBLAS.ddot(numFeatures, initialSolution, 1, scaledMean, 1)
+ initialSolution(numFeatures) += adapt
+ }
+
+ val states = optimizer.iterations(new CachedDiffFunction(costFun),
+ new BDV[Double](initialSolution))
+ val arrayBuilder = mutable.ArrayBuilder.make[Double]
+ var state: optimizer.State = null
+ while (states.hasNext) {
+ state = states.next()
+ arrayBuilder += state.adjustedValue
+ }
+ blocks.unpersist()
+ bcInverseStd.destroy()
+ bcScaledMean.destroy()
+
+ val solution = if (state == null) null else state.x.toArray
+ if ($(fitIntercept) && solution != null) {
+ // the final solution is for problem:
+ // y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept)
+ // we should adjust it back for original problem:
+ // y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept)
+ val adapt = BLAS.javaBLAS.ddot(numFeatures, solution, 1, scaledMean, 1)
+ solution(numFeatures) -= adapt
+ }
+ (solution, arrayBuilder.result)
+ }
+}
+
+@Since("2.2.0")
+object LinearSVC extends DefaultParamsReadable[LinearSVC] {
+
+ @Since("2.2.0")
+ override def load(path: String): LinearSVC = super.load(path)
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/FeatureEncoding.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/FeatureEncoding.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b44860d900748af90dbbc41479b8cd75de025358
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/FeatureEncoding.scala
@@ -0,0 +1,155 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.spark.ml.feature
+
+import java.io.File
+
+import scala.collection.{mutable, JavaConverters}
+import scala.collection.mutable.ArrayBuffer
+import scala.io.Source
+
+import com.fasterxml.jackson.databind.ObjectMapper
+
+import org.apache.spark.ml.StaticUtils
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.functions.{col, lit, udf}
+
+class FeatureEncoding extends Serializable{
+ var mapLoadPath = ""
+ var dataPath = ""
+ var outputFilePath = ""
+ var localSavePath = ""
+ var encodeColumns = Array[String]()
+ var numThread = 40
+
+ def setMapLoadPath(mapLoadPath: String): this.type = {
+ this.mapLoadPath = mapLoadPath
+ this
+ }
+
+ def setDataPath(dataPath: String): this.type = {
+ this.dataPath = dataPath
+ this
+ }
+
+ def setOutputFilePath(outputFilePath: String): this.type = {
+ this.outputFilePath = outputFilePath
+ this
+ }
+
+ def setLocalSavePath(localSavePath: String): this.type = {
+ this.localSavePath = localSavePath
+ this
+ }
+
+ def setEncodeColumns(encodeColumns: String): this.type = {
+ this.encodeColumns = encodeColumns.split(",")
+ this
+ }
+
+ def setNumThread(numThread: Int): this.type = {
+ this.numThread = numThread
+ this
+ }
+
+ def parseJsonToIntMap(json: String): mutable.Map[String, Int] = {
+ val mapper = new ObjectMapper()
+ val node = mapper.readValue(json, classOf[java.util.HashMap[String, Int]])
+ JavaConverters.mapAsScalaMap(node)
+ }
+
+ def loadJsonToString(path: String): String = {
+ Source.fromFile(path, "utf-8").mkString
+ }
+
+ def padZero(input: Array[Int], maxLength: Int): Array[Int] = {
+ if (input.length > maxLength) {
+ input.dropRight(input.length-maxLength)
+ } else {
+ input.++(Array.ofDim[Int](maxLength-input.length))
+ }
+ }
+
+ def transform(input: DataFrame, featureMapKey: String,
+ featureMap: Map[String, Int], inputCol: String*): DataFrame = {
+ if (featureMap.isEmpty) {
+ throw new Exception("featureMap is empty")
+ }
+
+ val suffixName = "_index"
+ val transformUDF = udf((maxLengthKey: String, value: String) => {
+ val transformList = ArrayBuffer[Int]()
+ if (featureMap.contains(featureMapKey + "," + value)) {
+ transformList.append(featureMap(featureMapKey + "," + value))
+ } else {
+ // use 1 as feature index if not found
+ transformList.append(1)
+ }
+
+ // return the maxLength array
+ padZero(transformList.toArray, StaticUtils.ONE_INT)
+ })
+
+ var data = input
+ for (cols <- inputCol) {
+ data = data.withColumn(
+ cols + suffixName,
+ transformUDF(
+ lit(cols),
+ col(cols)
+ ))
+ }
+ data
+ }
+
+ def dirDel(path: File) {
+ if (!path.exists()) {
+ return
+ }
+ if (path.isFile()) {
+ path.delete()
+ return
+ }
+ val file: Array[File] = path.listFiles()
+ for (d <- file) {
+ dirDel(d)
+ }
+ path.delete()
+ }
+
+ def copyFileToLocal(spark: SparkSession, hdfsPath: String, localPath: String): Unit = {
+ val localFilePath = new File(localPath)
+ dirDel(localFilePath)
+ if (!localFilePath.exists()) {
+ localFilePath.mkdirs()
+ }
+ EncoderUtils.save2PathPar(hdfsPath, localPath, numThread)
+ }
+
+ def execute(dataset: DataFrame = null): Unit = {
+ require(mapLoadPath.nonEmpty, "mapLoadPath is empty")
+ require(dataPath.nonEmpty, "dataPath is empty")
+ require(outputFilePath.nonEmpty, "outputFilePath is empty")
+ require(localSavePath.nonEmpty, "localSavePath is empty")
+ require(numThread > 0, "numThread is illegal")
+ val featureMap = parseJsonToIntMap(loadJsonToString(mapLoadPath))
+ var res = dataset
+ for(feature <- encodeColumns) {
+ require(res.columns.contains(feature), "non existent encodeColumns: " + feature)
+ res = transform(res, feature, featureMap.toMap, feature)
+ }
+ res
+ .select(encodeColumns.map{t => col(t + "_index")}: _*)
+ .write.mode("overwrite")
+ .save(outputFilePath)
+
+ copyFileToLocal(res.sparkSession, outputFilePath, localSavePath + "encode")
+ copyFileToLocal(res.sparkSession, dataPath, localSavePath + "data")
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/IDF.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e451d4daffbc75503f9d9355b84637d8a4a84364
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml._
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.VersionUtils.majorVersion
+
+/**
+ * Params for [[IDF]] and [[IDFModel]].
+ */
+private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol {
+
+ /**
+ * The minimum number of documents in which a term should appear.
+ * Default: 0
+ * @group param
+ */
+ final val minDocFreq = new IntParam(
+ this, "minDocFreq", "minimum number of documents in which a term should appear for filtering" +
+ " (>= 0)", ParamValidators.gtEq(0))
+
+ setDefault(minDocFreq -> 0)
+
+ /** @group getParam */
+ def getMinDocFreq: Int = $(minDocFreq)
+
+ /**
+ * Validate and transform the input schema.
+ */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+ }
+}
+
+/**
+ * Compute the Inverse Document Frequency (IDF) given a collection of documents.
+ */
+@Since("1.4.0")
+final class IDF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
+ extends Estimator[IDFModel] with IDFBase with DefaultParamsWritable {
+
+ @Since("1.4.0")
+ def this() = this(Identifiable.randomUID("idf"))
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
+
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): IDFModel = {
+ transformSchema(dataset.schema, logging = true)
+ val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
+ case Row(v: Vector) => OldVectors.fromML(v)
+ }
+ val idf = new feature.IDF($(minDocFreq)).fit(input)
+ copyValues(new IDFModel(uid, idf).setParent(this))
+ }
+
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ @Since("1.4.1")
+ override def copy(extra: ParamMap): IDF = defaultCopy(extra)
+}
+
+@Since("1.6.0")
+object IDF extends DefaultParamsReadable[IDF] {
+
+ @Since("1.6.0")
+ override def load(path: String): IDF = super.load(path)
+}
+
+/**
+ * Model fitted by [[IDF]].
+ */
+@Since("1.4.0")
+class IDFModel private[ml] (
+ @Since("1.4.0") override val uid: String,
+ idfModel: feature.IDFModel)
+ extends Model[IDFModel] with IDFBase with MLWritable {
+
+ import IDFModel._
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ @Since("1.4.0")
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val outputSchema = transformSchema(dataset.schema, logging = true)
+
+ val func = { vector: Vector =>
+ vector match {
+ case SparseVector(size, indices, values) =>
+ val (newIndices, newValues) = feature.IDFModel.transformSparse(idfModel.idf,
+ indices, values)
+ Vectors.sparse(size, newIndices, newValues)
+ case DenseVector(values) =>
+ val newValues = feature.IDFModel.transformDense(idfModel.idf, values)
+ Vectors.dense(newValues)
+ case other =>
+ throw new UnsupportedOperationException(
+ s"Only sparse and dense vectors are supported but got ${other.getClass}.")
+ }
+ }
+
+ val transformer = udf(func)
+ dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+ outputSchema($(outputCol)).metadata)
+ }
+
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(outputCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), idf.size)
+ }
+ outputSchema
+ }
+
+ @Since("1.4.1")
+ override def copy(extra: ParamMap): IDFModel = {
+ val copied = new IDFModel(uid, idfModel)
+ copyValues(copied, extra).setParent(parent)
+ }
+
+ /** Returns the IDF vector. */
+ @Since("2.0.0")
+ def idf: Vector = idfModel.idf.asML
+
+ /** Returns the document frequency */
+ @Since("3.0.0")
+ def docFreq: Array[Long] = idfModel.docFreq
+
+ /** Returns number of documents evaluated to compute idf */
+ @Since("3.0.0")
+ def numDocs: Long = idfModel.numDocs
+
+ @Since("1.6.0")
+ override def write: MLWriter = new IDFModelWriter(this)
+
+ @Since("3.0.0")
+ override def toString: String = {
+ s"IDFModel: uid=$uid, numDocs=$numDocs, numFeatures=${idf.size}"
+ }
+}
+
+@Since("1.6.0")
+object IDFModel extends MLReadable[IDFModel] {
+
+ private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter {
+
+ private case class Data(idf: Vector, docFreq: Array[Long], numDocs: Long)
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.idf, instance.docFreq, instance.numDocs)
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class IDFModelReader extends MLReader[IDFModel] {
+
+ private val className = classOf[IDFModel].getName
+
+ override def load(path: String): IDFModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val data = sparkSession.read.parquet(dataPath)
+
+ val model = if (majorVersion(metadata.sparkVersion) >= 3) {
+ val Row(idf: Vector, df: scala.collection.Seq[_], numDocs: Long) =
+ data.select("idf", "docFreq", "numDocs").head()
+ new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf),
+ df.asInstanceOf[scala.collection.Seq[Long]].toArray, numDocs))
+ } else {
+ val Row(idf: Vector) = MLUtils.convertVectorColumnsToML(data, "idf")
+ .select("idf")
+ .head()
+ new IDFModel(metadata.uid,
+ new feature.IDFModel(OldVectors.fromML(idf), new Array[Long](idf.size), 0L))
+ }
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: MLReader[IDFModel] = new IDFModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): IDFModel = super.load(path)
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index e3c4bae3629275d31cc7b52a5449bc8cb377e7f9..5ffc0accbbc3af13f28cc41e82d889109deb37de 100644
--- a/ml-accelerator/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -14,7 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.ml.recommendation
import java.{util => ju}
@@ -28,11 +27,12 @@ import scala.util.hashing.byteswap64
import breeze.linalg.blas.YTYUtils.compute
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import com.google.common.collect.{Ordering => GuavaOrdering}
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
-import org.apache.spark.{Partitioner, SparkException}
+import org.apache.spark.Partitioner
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
@@ -43,12 +43,12 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
-import org.apache.spark.rdd.{DeterministicLevel, RDD}
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
import org.apache.spark.util.random.XORShiftRandom
@@ -270,10 +270,10 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
*/
@Since("1.3.0")
class ALSModel private[ml] (
- @Since("1.4.0") override val uid: String,
- @Since("1.4.0") val rank: Int,
- @transient val userFactors: DataFrame,
- @transient val itemFactors: DataFrame)
+ @Since("1.4.0") override val uid: String,
+ @Since("1.4.0") val rank: Int,
+ @transient val userFactors: DataFrame,
+ @transient val itemFactors: DataFrame)
extends Model[ALSModel] with ALSModelParams with MLWritable {
/** @group setParam */
@@ -418,9 +418,9 @@ class ALSModel private[ml] (
* the factor DataFrame.
*/
private def getSourceFactorSubset(
- dataset: Dataset[_],
- factors: DataFrame,
- column: String): DataFrame = {
+ dataset: Dataset[_],
+ factors: DataFrame,
+ column: String): DataFrame = {
factors
.join(dataset.select(column), factors("id") === dataset(column), joinType = "left_semi")
.select(factors("id"), factors("features"))
@@ -435,7 +435,8 @@ class ALSModel private[ml] (
* relatively efficient, the approach implemented here is significantly more efficient.
*
* This approach groups factors into blocks and computes the top-k elements per block,
- * using dot product and an efficient [[BoundedPriorityQueue]] (instead of gemm).
+ * using GEMV (it use less memory compared with GEMM, and is much faster than DOT) and
+ * an efficient selection based on [[GuavaOrdering]] (instead of [[BoundedPriorityQueue]]).
* It then computes the global top-k by aggregating the per block top-k elements with
* a [[TopByKeyAggregator]]. This significantly reduces the size of intermediate and shuffle data.
* This is the DataFrame equivalent to the approach used in
@@ -450,37 +451,46 @@ class ALSModel private[ml] (
* stored as an array of (dstOutputColumn: Int, rating: Float) Rows.
*/
private def recommendForAll(
- srcFactors: DataFrame,
- dstFactors: DataFrame,
- srcOutputColumn: String,
- dstOutputColumn: String,
- num: Int,
- blockSize: Int): DataFrame = {
+ srcFactors: DataFrame,
+ dstFactors: DataFrame,
+ srcOutputColumn: String,
+ dstOutputColumn: String,
+ num: Int,
+ blockSize: Int): DataFrame = {
import srcFactors.sparkSession.implicits._
+ import scala.collection.JavaConverters._
val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
- .as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])]
- .flatMap { case (srcIter, dstIter) =>
- val m = srcIter.size
- val n = math.min(dstIter.size, num)
- val output = new Array[(Int, Int, Float)](m * n)
- var i = 0
- val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
- srcIter.foreach { case (srcId, srcFactor) =>
- dstIter.foreach { case (dstId, dstFactor) =>
- // We use F2jBLAS which is faster than a call to native BLAS for vector dot product
- val score = BLAS.f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1)
- pq += dstId -> score
+ .as[(Array[Int], Array[Float], Array[Int], Array[Float])]
+ .mapPartitions { iter =>
+ var scores: Array[Float] = null
+ var idxOrd: GuavaOrdering[Int] = null
+ iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
+ require(srcMat.length == srcIds.length * rank)
+ require(dstMat.length == dstIds.length * rank)
+ val m = srcIds.length
+ val n = dstIds.length
+ if (scores == null || scores.length < n) {
+ scores = Array.ofDim[Float](n)
+ idxOrd = new GuavaOrdering[Int] {
+ override def compare(left: Int, right: Int): Int = {
+ Ordering[Float].compare(scores(left), scores(right))
+ }
+ }
}
- pq.foreach { case (dstId, score) =>
- output(i) = (srcId, dstId, score)
- i += 1
+
+ Iterator.range(0, m).flatMap { i =>
+ // scores = i-th vec in srcMat * dstMat
+ BLAS.javaBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
+ srcMat, i * rank, 1, 0.0F, scores, 0, 1)
+
+ val srcId = srcIds(i)
+ idxOrd.greatestOf(Iterator.range(0, n).asJava, num).asScala
+ .iterator.map { j => (srcId, dstIds(j), scores(j)) }
}
- pq.clear()
}
- output.toSeq
}
// We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
@@ -499,10 +509,13 @@ class ALSModel private[ml] (
* Blockifies factors to improve the efficiency of cross join
*/
private def blockify(
- factors: Dataset[(Int, Array[Float])],
- blockSize: Int): Dataset[Seq[(Int, Array[Float])]] = {
+ factors: Dataset[(Int, Array[Float])],
+ blockSize: Int): Dataset[(Array[Int], Array[Float])] = {
import factors.sparkSession.implicits._
- factors.mapPartitions(_.grouped(blockSize))
+ factors.mapPartitions { iter =>
+ iter.grouped(blockSize)
+ .map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))
+ }
}
}
@@ -878,7 +891,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
def copyATA(source: Array[Double]): this.type = {
- blas.daxpy(ata.length, 1.0, source, 1, ata, 1)
+ BLAS.nativeBLAS.daxpy(ata.length, 1.0, source, 1, ata, 1)
this
}
@@ -938,20 +951,20 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* submatrix of the ratings matrix.
*/
def train[ID: ClassTag]( // scalastyle:ignore
- ratings: RDD[Rating[ID]],
- rank: Int = 10,
- numUserBlocks: Int = 10,
- numItemBlocks: Int = 10,
- maxIter: Int = 10,
- regParam: Double = 0.1,
- implicitPrefs: Boolean = false,
- alpha: Double = 1.0,
- nonnegative: Boolean = false,
- intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
- finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
- checkpointInterval: Int = 10,
- seed: Long = 0L)(
- implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
+ ratings: RDD[Rating[ID]],
+ rank: Int = 10,
+ numUserBlocks: Int = 10,
+ numItemBlocks: Int = 10,
+ maxIter: Int = 10,
+ regParam: Double = 0.1,
+ implicitPrefs: Boolean = false,
+ alpha: Double = 1.0,
+ nonnegative: Boolean = false,
+ intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+ finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+ checkpointInterval: Int = 10,
+ seed: Long = 0L)(
+ implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
require(!ratings.isEmpty(), s"No ratings available from $ratings")
require(intermediateRDDStorageLevel != StorageLevel.NONE,
@@ -1286,10 +1299,10 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* @see [[LocalIndexEncoder]]
*/
private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag](
- srcIds: Array[ID],
- dstPtrs: Array[Int],
- dstEncodedIndices: Array[Int],
- ratings: Array[Float]) {
+ srcIds: Array[ID],
+ dstPtrs: Array[Int],
+ dstEncodedIndices: Array[Int],
+ ratings: Array[Float]) {
/** Size of the block. */
def size: Int = ratings.length
require(dstEncodedIndices.length == size)
@@ -1304,12 +1317,11 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* @return initialized factor blocks
*/
private def initialize[ID](
- inBlocks: RDD[(Int, InBlock[ID])],
- rank: Int,
- seed: Long): RDD[(Int, FactorBlock)] = {
- // Choose a unit vector uniformly at random from the unit sphere, but from the
- // "first quadrant" where all elements are nonnegative. This can be done by choosing
- // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
+ inBlocks: RDD[(Int, InBlock[ID])],
+ rank: Int,
+ seed: Long): RDD[(Int, FactorBlock)] = {
+ // Choose a unit vector uniformly at random from the unit sphere. This can be done by choosing
+ // elements distributed as Normal(0,1), and then normalizing.
// This appears to create factorizations that have a slightly better reconstruction
// (<1%) compared picking elements uniformly at random in [0,1].
inBlocks.mapPartitions({ iter =>
@@ -1331,9 +1343,9 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
*/
private[recommendation] case class RatingBlock[@specialized(Int, Long) ID: ClassTag](
- srcIds: Array[ID],
- dstIds: Array[ID],
- ratings: Array[Float]) {
+ srcIds: Array[ID],
+ dstIds: Array[ID],
+ ratings: Array[Float]) {
/** Size of the block. */
def size: Int = srcIds.length
require(dstIds.length == srcIds.length)
@@ -1401,9 +1413,9 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
*/
private def partitionRatings[ID: ClassTag](
- ratings: RDD[Rating[ID]],
- srcPart: Partitioner,
- dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = {
+ ratings: RDD[Rating[ID]],
+ srcPart: Partitioner,
+ dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = {
val numPartitions = srcPart.numPartitions * dstPart.numPartitions
ratings.mapPartitions { iter =>
val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID])
@@ -1439,8 +1451,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* @param encoder encoder for dst indices
*/
private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag](
- encoder: LocalIndexEncoder)(
- implicit ord: Ordering[ID]) {
+ encoder: LocalIndexEncoder)(
+ implicit ord: Ordering[ID]) {
private val srcIds = mutable.ArrayBuilder.make[ID]
private val dstEncodedIndices = mutable.ArrayBuilder.make[Int]
@@ -1455,10 +1467,10 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* @param ratings ratings
*/
def add(
- dstBlockId: Int,
- srcIds: Array[ID],
- dstLocalIndices: Array[Int],
- ratings: Array[Float]): this.type = {
+ dstBlockId: Int,
+ srcIds: Array[ID],
+ dstLocalIndices: Array[Int],
+ ratings: Array[Float]): this.type = {
val sz = srcIds.length
require(dstLocalIndices.length == sz)
require(ratings.length == sz)
@@ -1482,10 +1494,10 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays.
*/
private[recommendation] class UncompressedInBlock[@specialized(Int, Long) ID: ClassTag](
- val srcIds: Array[ID],
- val dstEncodedIndices: Array[Int],
- val ratings: Array[Float])(
- implicit ord: Ordering[ID]) {
+ val srcIds: Array[ID],
+ val dstEncodedIndices: Array[Int],
+ val ratings: Array[Float])(
+ implicit ord: Ordering[ID]) {
/** Size the of block. */
def length: Int = srcIds.length
@@ -1550,7 +1562,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* @see [[UncompressedInBlockSort]]
*/
private class KeyWrapper[@specialized(Int, Long) ID: ClassTag](
- implicit ord: Ordering[ID]) extends Ordered[KeyWrapper[ID]] {
+ implicit ord: Ordering[ID]) extends Ordered[KeyWrapper[ID]] {
var key: ID = _
@@ -1568,15 +1580,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]].
*/
private class UncompressedInBlockSort[@specialized(Int, Long) ID: ClassTag](
- implicit ord: Ordering[ID])
+ implicit ord: Ordering[ID])
extends SortDataFormat[KeyWrapper[ID], UncompressedInBlock[ID]] {
override def newKey(): KeyWrapper[ID] = new KeyWrapper()
override def getKey(
- data: UncompressedInBlock[ID],
- pos: Int,
- reuse: KeyWrapper[ID]): KeyWrapper[ID] = {
+ data: UncompressedInBlock[ID],
+ pos: Int,
+ reuse: KeyWrapper[ID]): KeyWrapper[ID] = {
if (reuse == null) {
new KeyWrapper().setKey(data.srcIds(pos))
} else {
@@ -1585,15 +1597,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
override def getKey(
- data: UncompressedInBlock[ID],
- pos: Int): KeyWrapper[ID] = {
+ data: UncompressedInBlock[ID],
+ pos: Int): KeyWrapper[ID] = {
getKey(data, pos, null)
}
private def swapElements[@specialized(Int, Float) T](
- data: Array[T],
- pos0: Int,
- pos1: Int): Unit = {
+ data: Array[T],
+ pos0: Int,
+ pos1: Int): Unit = {
val tmp = data(pos0)
data(pos0) = data(pos1)
data(pos1) = tmp
@@ -1606,11 +1618,11 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
override def copyRange(
- src: UncompressedInBlock[ID],
- srcPos: Int,
- dst: UncompressedInBlock[ID],
- dstPos: Int,
- length: Int): Unit = {
+ src: UncompressedInBlock[ID],
+ srcPos: Int,
+ dst: UncompressedInBlock[ID],
+ dstPos: Int,
+ length: Int): Unit = {
System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length)
System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length)
System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length)
@@ -1622,10 +1634,10 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
override def copyElement(
- src: UncompressedInBlock[ID],
- srcPos: Int,
- dst: UncompressedInBlock[ID],
- dstPos: Int): Unit = {
+ src: UncompressedInBlock[ID],
+ srcPos: Int,
+ dst: UncompressedInBlock[ID],
+ dstPos: Int): Unit = {
dst.srcIds(dstPos) = src.srcIds(srcPos)
dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos)
dst.ratings(dstPos) = src.ratings(srcPos)
@@ -1642,12 +1654,12 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* @return (in-blocks, out-blocks)
*/
private def makeBlocks[ID: ClassTag](
- prefix: String,
- ratingBlocks: RDD[((Int, Int), RatingBlock[ID])],
- srcPart: Partitioner,
- dstPart: Partitioner,
- storageLevel: StorageLevel)(
- implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = {
+ prefix: String,
+ ratingBlocks: RDD[((Int, Int), RatingBlock[ID])],
+ srcPart: Partitioner,
+ dstPart: Partitioner,
+ storageLevel: StorageLevel)(
+ implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = {
val inBlocks = ratingBlocks.map {
case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
// The implementation is a faster version of
@@ -1712,16 +1724,16 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
def computeFactorsNew[ID](
- srcFactorBlocks: RDD[(Int, FactorBlock)],
- dpl: Int,
- r: RDD[(Int, (Array[(Int, Array[Int])], InBlock[ID]))],
- rank: Int,
- regParam: Double,
- srcEncoder: LocalIndexEncoder,
- implicitPrefs: Boolean = false,
- alpha: Double = 1.0,
- solver: LeastSquaresNESolver,
- blockMaxRow: Int): RDD[(Int, FactorBlock)] = {
+ srcFactorBlocks: RDD[(Int, FactorBlock)],
+ dpl: Int,
+ r: RDD[(Int, (Array[(Int, Array[Int])], InBlock[ID]))],
+ rank: Int,
+ regParam: Double,
+ srcEncoder: LocalIndexEncoder,
+ implicitPrefs: Boolean = false,
+ alpha: Double = 1.0,
+ solver: LeastSquaresNESolver,
+ blockMaxRow: Int): RDD[(Int, FactorBlock)] = {
val numSrcBlocks = srcFactorBlocks.partitions.length
val srcFactorMap = srcFactorBlocks.collectAsMap()
@@ -1785,8 +1797,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* Caching of the input factors is handled in [[ALS#train]].
*/
private def computeYtY(factorBlocks: Array[FactorBlock],
- rank: Int,
- blockMaxRow: Int): NormalEquation = {
+ rank: Int,
+ blockMaxRow: Int): NormalEquation = {
val ne = new NormalEquation(rank)
val triK = rank * (rank + 1) / 2
val c = Array.fill(triK)(0.0f)
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/recommendation/NMF.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/recommendation/NMF.scala
new file mode 100644
index 0000000000000000000000000000000000000000..2739fe2d200d8f3dd8b43999d5594ca5f9606529
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/recommendation/NMF.scala
@@ -0,0 +1,1692 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import java.{util => ju}
+import java.io.IOException
+import java.util.Locale
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+import scala.util.{Sorting, Try}
+import scala.util.hashing.byteswap64
+
+import com.google.common.collect.{Ordering => GuavaOrdering}
+import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+
+import org.apache.spark.{Partitioner, SparkContext, SparkException}
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.linalg.BLAS
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.mllib.optimization.NNLS
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Common params for NMF and NMFModel.
+ */
+private[recommendation] trait NMFModelParams extends Params with HasPredictionCol
+ with HasBlockSize {
+ /**
+ * Param for the column name for user ids. Ids must be integers. Other
+ * numeric types are supported for this column, but will be cast to integers as long as they
+ * fall within the integer value range.
+ * Default: "user"
+ *
+ * @group param
+ */
+ val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " +
+ "the integer value range.")
+
+ /** @group getParam */
+ def getUserCol: String = $(userCol)
+
+ /**
+ * Param for the column name for item ids. Ids must be integers. Other
+ * numeric types are supported for this column, but will be cast to integers as long as they
+ * fall within the integer value range.
+ * Default: "item"
+ *
+ * @group param
+ */
+ val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " +
+ "the integer value range.")
+
+ /** @group getParam */
+ def getItemCol: String = $(itemCol)
+
+ /**
+ * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
+ * out of integer range or contains a fractional part.
+ */
+ protected[recommendation] val checkedCast = udf { (n: Any) =>
+ n match {
+ case v: Int => v // Avoid unnecessary casting
+ case v: Number =>
+ val intV = v.intValue
+ // Checks if number within Int range and has no fractional part.
+ if (v.doubleValue == intV) {
+ intV
+ } else {
+ throw new IllegalArgumentException(s"NMF only supports values in Integer range " +
+ s"and without fractional part for columns ${$(userCol)} and ${$(itemCol)}. " +
+ s"Value $n was either out of Integer range or contained a fractional part that " +
+ s"could not be converted.")
+ }
+ case _ => throw new IllegalArgumentException(s"NMF only supports values in Integer range " +
+ s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was not numeric.")
+ }
+ }
+
+ /**
+ * Param for strategy for dealing with unknown or new users/items at prediction time.
+ * This may be useful in cross-validation or production scenarios, for handling user/item ids
+ * the model has not seen in the training data.
+ * Supported values:
+ * - "nan": predicted value for unknown ids will be NaN.
+ * - "drop": rows in the input DataFrame containing unknown ids will be dropped from
+ * the output DataFrame containing predictions.
+ * Default: "nan".
+ *
+ * @group expertParam
+ */
+ val coldStartStrategy = new Param[String](this, "coldStartStrategy",
+ "strategy for dealing with unknown or new users/items at prediction time. This may be " +
+ "useful in cross-validation or production scenarios, for handling user/item ids the model " +
+ "has not seen in the training data. Supported values: " +
+ s"${NMFModel.supportedColdStartStrategies.mkString(",")}.",
+ (s: String) =>
+ NMFModel.supportedColdStartStrategies.contains(s.toLowerCase(Locale.ROOT)))
+
+ /** @group expertGetParam */
+ def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase(Locale.ROOT)
+
+ setDefault(blockSize -> 4096)
+}
+
+/**
+ * Common params for NMF.
+ */
+private[recommendation] trait NMFParams extends NMFModelParams with HasMaxIter with HasRegParam
+ with HasCheckpointInterval with HasSeed {
+
+ /**
+ * Param for rank of the matrix factorization (positive).
+ * Default: 10
+ * @group param
+ */
+ val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidators.gtEq(1))
+
+ /** @group getParam */
+ def getRank: Int = $(rank)
+
+ /**
+ * Param for number of user blocks (positive).
+ * Default: 10
+ * @group param
+ */
+ val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks",
+ ParamValidators.gtEq(1))
+
+ /** @group getParam */
+ def getNumUserBlocks: Int = $(numUserBlocks)
+
+ /**
+ * Param for number of item blocks (positive).
+ * Default: 10
+ * @group param
+ */
+ val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks",
+ ParamValidators.gtEq(1))
+
+ /** @group getParam */
+ def getNumItemBlocks: Int = $(numItemBlocks)
+
+ /**
+ * Param for the column name for ratings.
+ * Default: "rating"
+ *
+ * @group param
+ */
+ val ratingCol = new Param[String](this, "ratingCol", "column name for ratings")
+
+ /** @group getParam */
+ def getRatingCol: String = $(ratingCol)
+
+ /**
+ * Param for StorageLevel for intermediate datasets. Pass in a string representation of
+ * `StorageLevel`. Cannot be "NONE".
+ * Default: "MEMORY_AND_DISK".
+ *
+ * @group expertParam
+ */
+ val intermediateStorageLevel = new Param[String](this, "intermediateStorageLevel",
+ "StorageLevel for intermediate datasets. Cannot be 'NONE'.",
+ (s: String) => Try(StorageLevel.fromString(s)).isSuccess && s != "NONE")
+
+ /** @group expertGetParam */
+ def getIntermediateStorageLevel: String = $(intermediateStorageLevel)
+
+ /**
+ * Param for StorageLevel for NMF model factors. Pass in a string representation of
+ * `StorageLevel`.
+ * Default: "MEMORY_AND_DISK".
+ *
+ * @group expertParam
+ */
+ val finalStorageLevel = new Param[String](this, "finalStorageLevel",
+ "StorageLevel for NMF model factors.",
+ (s: String) => Try(StorageLevel.fromString(s)).isSuccess)
+
+ /** @group expertGetParam */
+ def getFinalStorageLevel: String = $(finalStorageLevel)
+
+ setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
+ userCol -> "user", itemCol -> "item", ratingCol -> "rating", checkpointInterval -> 10,
+ intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
+ coldStartStrategy -> "nan")
+
+ /**
+ * Validates and transforms the input schema.
+ *
+ * @param schema input schema
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ // user and item will be cast to Int
+ SchemaUtils.checkNumericType(schema, $(userCol))
+ SchemaUtils.checkNumericType(schema, $(itemCol))
+ // rating will be cast to Float
+ SchemaUtils.checkNumericType(schema, $(ratingCol))
+ SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
+ }
+}
+
+/**
+ * Model fitted by NMF.
+ *
+ * @param rank rank of the matrix factorization model
+ * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features`
+ * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features`
+ */
+class NMFModel private[ml](
+ override val uid: String,
+ val rank: Int,
+ @transient val userFactors: DataFrame,
+ @transient val itemFactors: DataFrame)
+ extends Model[NMFModel] with NMFModelParams with MLWritable {
+
+ /** @group setParam */
+ def setUserCol(value: String): this.type = set(userCol, value)
+
+ /** @group setParam */
+ def setItemCol(value: String): this.type = set(itemCol, value)
+
+ /** @group setParam */
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group expertSetParam */
+ def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+
+ /**
+ * Set block size for stacking input data in matrices.
+ * Default is 4096.
+ *
+ * @group expertSetParam
+ */
+ def setBlockSize(value: Int): this.type = set(blockSize, value)
+
+ private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) =>
+ if (featuresA != null && featuresB != null) {
+ var dotProduct = 0.0f
+ var i = 0
+ while (i < rank) {
+ dotProduct += featuresA(i) * featuresB(i)
+ i += 1
+ }
+ dotProduct
+ } else {
+ Float.NaN
+ }
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema)
+ // create a new column named map(predictionCol) by running the predict UDF.
+ val predictions = dataset
+ .join(userFactors,
+ checkedCast(dataset($(userCol))) === userFactors("id"), "left")
+ .join(itemFactors,
+ checkedCast(dataset($(itemCol))) === itemFactors("id"), "left")
+ .select(dataset("*"),
+ predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
+ getColdStartStrategy match {
+ case NMFModel.Drop =>
+ predictions.na.drop("all", Seq($(predictionCol)))
+ case NMFModel.NaN =>
+ predictions
+ }
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ // user and item will be cast to Int
+ SchemaUtils.checkNumericType(schema, $(userCol))
+ SchemaUtils.checkNumericType(schema, $(itemCol))
+ SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
+ }
+
+ override def copy(extra: ParamMap): NMFModel = {
+ val copied = new NMFModel(uid, rank, userFactors, itemFactors)
+ copyValues(copied, extra).setParent(parent)
+ }
+
+ override def write: MLWriter = new NMFModel.NMFModelWriter(this)
+
+ override def toString: String = {
+ s"NMFModel: uid=$uid, rank=$rank"
+ }
+
+ /**
+ * Returns top `numItems` items recommended for each user, for all users.
+ * @param numItems max number of recommendations for each user
+ * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
+ * stored as an array of (itemCol: Int, rating: Float) Rows.
+ */
+ def recommendForAllUsers(numItems: Int): DataFrame = {
+ recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems, $(blockSize))
+ }
+
+ /**
+ * Returns top `numItems` items recommended for each user id in the input data set. Note that if
+ * there are duplicate ids in the input dataset, only one set of recommendations per unique id
+ * will be returned.
+ *
+ * @param dataset a Dataset containing a column of user ids.
+ * The column name must match `userCol`.
+ * @param numItems max number of recommendations for each user.
+ * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
+ * stored as an array of (itemCol: Int, rating: Float) Rows.
+ */
+ def recommendForUserSubset(dataset: Dataset[_], numItems: Int): DataFrame = {
+ val srcFactorSubset = getSourceFactorSubset(dataset, userFactors, $(userCol))
+ recommendForAll(srcFactorSubset, itemFactors, $(userCol), $(itemCol), numItems, $(blockSize))
+ }
+
+ /**
+ * Returns top `numUsers` users recommended for each item, for all items.
+ * @param numUsers max number of recommendations for each item
+ * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
+ * stored as an array of (userCol: Int, rating: Float) Rows.
+ */
+ def recommendForAllItems(numUsers: Int): DataFrame = {
+ recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers, $(blockSize))
+ }
+
+ /**
+ * Returns top `numUsers` users recommended for each item id in the input data set. Note that if
+ * there are duplicate ids in the input dataset, only one set of recommendations per unique id
+ * will be returned.
+ *
+ * @param dataset a Dataset containing a column of item ids.
+ * The column name must match `itemCol`.
+ * @param numUsers max number of recommendations for each item.
+ * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
+ * stored as an array of (userCol: Int, rating: Float) Rows.
+ */
+ def recommendForItemSubset(dataset: Dataset[_], numUsers: Int): DataFrame = {
+ val srcFactorSubset = getSourceFactorSubset(dataset, itemFactors, $(itemCol))
+ recommendForAll(srcFactorSubset, userFactors, $(itemCol), $(userCol), numUsers, $(blockSize))
+ }
+
+ /**
+ * Returns a subset of a factor DataFrame limited to only those unique ids contained
+ * in the input dataset.
+ * @param dataset input Dataset containing id column to user to filter factors.
+ * @param factors factor DataFrame to filter.
+ * @param column column name containing the ids in the input dataset.
+ * @return DataFrame containing factors only for those ids present in both the input dataset and
+ * the factor DataFrame.
+ */
+ private def getSourceFactorSubset(
+ dataset: Dataset[_],
+ factors: DataFrame,
+ column: String): DataFrame = {
+ factors
+ .join(dataset.select(column), factors("id") === dataset(column), joinType = "left_semi")
+ .select(factors("id"), factors("features"))
+ }
+
+ /**
+ * Makes recommendations for all users (or items).
+ *
+ * Note: the previous approach used for computing top-k recommendations
+ * used a cross-join followed by predicting a score for each row of the joined dataset.
+ * However, this results in exploding the size of intermediate data. While Spark SQL makes it
+ * relatively efficient, the approach implemented here is significantly more efficient.
+ *
+ * This approach groups factors into blocks and computes the top-k elements per block,
+ * using GEMV (it use less memory compared with GEMM, and is much faster than DOT) and
+ * an efficient selection based on [[GuavaOrdering]] (instead of [[BoundedPriorityQueue]]).
+ * It then computes the global top-k by aggregating the per block top-k elements with
+ * a [[TopByKeyAggregator]]. This significantly reduces the size of intermediate and shuffle data.
+ * This is the DataFrame equivalent to the approach used in
+ * [[org.apache.spark.mllib.recommendation.MatrixFactorizationModel]].
+ *
+ * @param srcFactors src factors for which to generate recommendations
+ * @param dstFactors dst factors used to make recommendations
+ * @param srcOutputColumn name of the column for the source ID in the output DataFrame
+ * @param dstOutputColumn name of the column for the destination ID in the output DataFrame
+ * @param num max number of recommendations for each record
+ * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
+ * stored as an array of (dstOutputColumn: Int, rating: Float) Rows.
+ */
+ private def recommendForAll(
+ srcFactors: DataFrame,
+ dstFactors: DataFrame,
+ srcOutputColumn: String,
+ dstOutputColumn: String,
+ num: Int,
+ blockSize: Int): DataFrame = {
+ import srcFactors.sparkSession.implicits._
+ import scala.collection.JavaConverters._
+
+ val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
+ val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
+ val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
+ .as[(Array[Int], Array[Float], Array[Int], Array[Float])]
+ .mapPartitions { iter =>
+ var scores: Array[Float] = null
+ var idxOrd: GuavaOrdering[Int] = null
+ iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
+ require(srcMat.length == srcIds.length * rank)
+ require(dstMat.length == dstIds.length * rank)
+ val m = srcIds.length
+ val n = dstIds.length
+ if (scores == null || scores.length < n) {
+ scores = Array.ofDim[Float](n)
+ idxOrd = new GuavaOrdering[Int] {
+ override def compare(left: Int, right: Int): Int = {
+ Ordering[Float].compare(scores(left), scores(right))
+ }
+ }
+ }
+
+ Iterator.range(0, m).flatMap { i =>
+ // scores = i-th vec in srcMat * dstMat
+ BLAS.javaBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
+ srcMat, i * rank, 1, 0.0F, scores, 0, 1)
+
+ val srcId = srcIds(i)
+ idxOrd.greatestOf(Iterator.range(0, n).asJava, num).asScala
+ .iterator.map { j => (srcId, dstIds(j), scores(j)) }
+ }
+ }
+ }
+ // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
+ val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
+ val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
+ .toDF("id", "recommendations")
+
+ val arrayType = ArrayType(
+ new StructType()
+ .add(dstOutputColumn, IntegerType)
+ .add("rating", FloatType)
+ )
+ recs.select($"id".as(srcOutputColumn), $"recommendations".cast(arrayType))
+ }
+
+ /**
+ * Blockifies factors to improve the efficiency of cross join
+ */
+ private def blockify(
+ factors: Dataset[(Int, Array[Float])],
+ blockSize: Int): Dataset[(Array[Int], Array[Float])] = {
+ import factors.sparkSession.implicits._
+ factors.mapPartitions { iter =>
+ iter.grouped(blockSize)
+ .map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))
+ }
+ }
+
+}
+
+object NMFModel extends MLReadable[NMFModel] {
+
+ private val NaN = "nan"
+ private val Drop = "drop"
+ private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop)
+
+ override def read: MLReader[NMFModel] = new NMFModelReader
+
+ override def load(path: String): NMFModel = super.load(path)
+
+ private[NMFModel] class NMFModelWriter(instance: NMFModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata = "rank" -> instance.rank
+ DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+ val userPath = new Path(path, "userFactors").toString
+ instance.userFactors.write.format("parquet").save(userPath)
+ val itemPath = new Path(path, "itemFactors").toString
+ instance.itemFactors.write.format("parquet").save(itemPath)
+ }
+ }
+
+ private class NMFModelReader extends MLReader[NMFModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[NMFModel].getName
+
+ override def load(path: String): NMFModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ implicit val format = DefaultFormats
+ val rank = (metadata.metadata \ "rank").extract[Int]
+ val userPath = new Path(path, "userFactors").toString
+ val userFactors = sparkSession.read.format("parquet").load(userPath)
+ val itemPath = new Path(path, "itemFactors").toString
+ val itemFactors = sparkSession.read.format("parquet").load(itemPath)
+
+ val model = new NMFModel(metadata.uid, rank, userFactors, itemFactors)
+
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+}
+
+/**
+ * Alternating Least Squares (NMF) matrix factorization.
+ *
+ * NMF attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
+ * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices.
+ * The general approach is iterative. During each iteration, one of the factor matrices is held
+ * constant, while the other is solved for using least squares. The newly-solved factor matrix is
+ * then held constant while solving for the other factor matrix.
+ *
+ * This is a blocked implementation of the NMF factorization algorithm that groups the two sets
+ * of factors (referred to as "users" and "products") into blocks and reduces communication by only
+ * sending one copy of each user vector to each product block on each iteration, and only for the
+ * product blocks that need that user's feature vector. This is achieved by pre-computing some
+ * information about the ratings matrix to determine the "out-links" of each user (which blocks of
+ * products it will contribute to) and "in-link" information for each product (which of the feature
+ * vectors it receives from each user block it will depend on). This allows us to send only an
+ * array of feature vectors between each user block and product block, and have the product block
+ * find the users' ratings and update the products based on these messages.
+ *
+ * For implicit preference data, the algorithm used is based on
+ * "Collaborative Filtering for Implicit Feedback Datasets", available at
+ * https://doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here.
+ *
+ * Essentially instead of finding the low-rank approximations to the rating matrix `R`,
+ * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
+ * r is greater than 0 and 0 if r is less than or equal to 0. The ratings then act as 'confidence'
+ * values related to strength of indicated user
+ * preferences rather than explicit ratings given to items.
+ *
+ * Note: the input rating dataset to the NMF implementation should be deterministic.
+ * Nondeterministic data can cause failure during fitting NMF model.
+ * For example, an order-sensitive operation like sampling after a repartition makes dataset
+ * output nondeterministic, like `dataset.repartition(2).sample(false, 0.5, 1618)`.
+ * Checkpointing sampled dataset or adding a sort before sampling can help make the dataset
+ * deterministic.
+ */
+class NMF(override val uid: String) extends Estimator[NMFModel] with NMFParams
+ with DefaultParamsWritable {
+
+ import org.apache.spark.ml.recommendation.NMF.Rating
+
+ def this() = this(Identifiable.randomUID("nmf"))
+
+ /** @group setParam */
+ def setRank(value: Int): this.type = set(rank, value)
+
+ /** @group setParam */
+ def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value)
+
+ /** @group setParam */
+ def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value)
+
+ /** @group setParam */
+ def setUserCol(value: String): this.type = set(userCol, value)
+
+ /** @group setParam */
+ def setItemCol(value: String): this.type = set(itemCol, value)
+
+ /** @group setParam */
+ def setRatingCol(value: String): this.type = set(ratingCol, value)
+
+ /** @group setParam */
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /** @group setParam */
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ /** @group expertSetParam */
+ def setIntermediateStorageLevel(value: String): this.type = set(intermediateStorageLevel, value)
+
+ /** @group expertSetParam */
+ def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)
+
+ /** @group expertSetParam */
+ def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+
+ /**
+ * Set block size for stacking input data in matrices.
+ * Default is 4096.
+ *
+ * @group expertSetParam
+ */
+ def setBlockSize(value: Int): this.type = set(blockSize, value)
+
+ /**
+ * Sets both numUserBlocks and numItemBlocks to the specific value.
+ *
+ * @group setParam
+ */
+ def setNumBlocks(value: Int): this.type = {
+ setNumUserBlocks(value)
+ setNumItemBlocks(value)
+ this
+ }
+
+ override def fit(dataset: Dataset[_]): NMFModel = instrumented { instr =>
+ transformSchema(dataset.schema)
+ import dataset.sparkSession.implicits._
+
+ val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
+
+ val ratings = dataset
+ .select(checkedCast(col($(userCol))), checkedCast(col($(itemCol))), r)
+ .rdd
+ .map { row =>
+ Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+ }
+
+ instr.logPipelineStage(this)
+ instr.logDataset(dataset)
+ instr.logParams(this, rank, numUserBlocks, numItemBlocks, userCol,
+ itemCol, ratingCol, predictionCol, maxIter, regParam, checkpointInterval,
+ seed, intermediateStorageLevel, finalStorageLevel, blockSize)
+
+
+ val (userFactors, itemFactors) = NMF.train(ratings, rank = $(rank),
+ ($(numUserBlocks), $(numItemBlocks)),
+ maxIter = $(maxIter), regParam = $(regParam),
+ intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateStorageLevel)),
+ finalRDDStorageLevel = StorageLevel.fromString($(finalStorageLevel)),
+ checkpointInterval = $(checkpointInterval), seed = $(seed))
+ val userDF = userFactors.toDF("id", "features")
+ val itemDF = itemFactors.toDF("id", "features")
+ val model = new NMFModel(uid, $(rank), userDF, itemDF).setBlockSize($(blockSize))
+ .setParent(this)
+ copyValues(model)
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ override def copy(extra: ParamMap): NMF = defaultCopy(extra)
+}
+
+
+/**
+ * An implementation of NMF that supports generic ID types, specialized for Int and Long. This is
+ * exposed as a developer API for users who do need other ID types. But it is not recommended
+ * because it increases the shuffle size and memory requirement during training. For simplicity,
+ * users and items must have the same type. The number of distinct users/items should be smaller
+ * than 2 billion.
+ */
+object NMF extends DefaultParamsReadable[NMF] with Logging {
+
+ /**
+ * Rating class for better code readability.
+ */
+ case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
+
+ override def load(path: String): NMF = super.load(path)
+
+ /** Trait for least squares solvers applied to the normal equation. */
+ private[recommendation] trait LeastSquaresNESolver extends Serializable {
+ /** Solves a least squares problem with regularization (possibly with other constraints). */
+ def solve(ne: NormalEquation, lambda: Double): Array[Float]
+ }
+
+ /** NNLS solver. */
+ private[recommendation] class NNLSSolver extends LeastSquaresNESolver {
+ private var rank: Int = -1
+ private var workspace: NNLS.Workspace = _
+ private var ata: Array[Double] = _
+ private var initialized: Boolean = false
+
+ private def initialize(rank: Int): Unit = {
+ if (!initialized) {
+ this.rank = rank
+ workspace = NNLS.createWorkspace(rank)
+ ata = new Array[Double](rank * rank)
+ initialized = true
+ } else {
+ require(this.rank == rank)
+ }
+ }
+
+ /**
+ * Solves a nonnegative least squares problem with L2 regularization:
+ *
+ * min_x_ norm(A x - b)^2^ + lambda * n * norm(x)^2^
+ * subject to x >= 0
+ */
+ override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
+ val rank = ne.k
+ initialize(rank)
+ fillAtA(ne.ata, lambda)
+ val x = NNLS.solve(ata, ne.atb, workspace)
+ ne.reset()
+ x.map(x => x.toFloat)
+ }
+
+ /**
+ * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square
+ * matrix that it represents, storing it into destMatrix.
+ */
+ private def fillAtA(triAtA: Array[Double], lambda: Double): Unit = {
+ var i = 0
+ var pos = 0
+ var a = 0.0
+ while (i < rank) {
+ var j = 0
+ while (j <= i) {
+ a = triAtA(pos)
+ ata(i * rank + j) = a
+ ata(j * rank + i) = a
+ pos += 1
+ j += 1
+ }
+ ata(i * rank + i) += lambda
+ i += 1
+ }
+ }
+ }
+
+ /**
+ * Representing a normal equation to solve the following weighted least squares problem:
+ *
+ * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - d,,i,,)^2^ + lambda * x^T^ x.
+ *
+ * Its normal equation is given by
+ *
+ * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - d,,i,, a,,i,,) + lambda * x = 0.
+ *
+ * Distributing and letting b,,i,, = c,,i,, * d,,i,,
+ *
+ * \sum,,i,, c,,i,, a,,i,, a,,i,,^T^ x - b,,i,, a,,i,, + lambda * x = 0.
+ */
+ private[recommendation] class NormalEquation(val k: Int) extends Serializable {
+
+ /** Number of entries in the upper triangular part of a k-by-k matrix. */
+ val triK = k * (k + 1) / 2
+ /** A^T^ * A */
+ val ata = new Array[Double](triK)
+ /** A^T^ * b */
+ val atb = new Array[Double](k)
+
+ private val da = new Array[Double](k)
+ private val upper = "U"
+
+ private def copyToDouble(a: Array[Float]): Unit = {
+ var i = 0
+ while (i < k) {
+ da(i) = a(i)
+ i += 1
+ }
+ }
+
+ /** Adds an observation. */
+ def add(a: Array[Float], b: Double, c: Double = 1.0): NormalEquation = {
+ require(c >= 0.0)
+ require(a.length == k)
+ copyToDouble(a)
+ BLAS.nativeBLAS.dspr(upper, k, c, da, 1, ata)
+ if (b != 0.0) {
+ BLAS.nativeBLAS.daxpy(k, b, da, 1, atb, 1)
+ }
+ this
+ }
+
+ /** Merges another normal equation object. */
+ def merge(other: NormalEquation): NormalEquation = {
+ require(other.k == k)
+ BLAS.nativeBLAS.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
+ BLAS.nativeBLAS.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
+ this
+ }
+
+ /** Resets everything to zero, which should be called after each solve. */
+ def reset(): Unit = {
+ ju.Arrays.fill(ata, 0.0)
+ ju.Arrays.fill(atb, 0.0)
+ }
+ }
+
+ /**
+ * Implementation of the NMF algorithm.
+ * @param ratings rating data
+ * @param rank rank
+ * @param numBlocks (numUserBlocks, numItemBlocks)
+ * @param maxIter max iterations
+ * @param regParam regularization parameter
+ * @param intermediateRDDStorageLevel intermediate RDD StorageLevel
+ * @param finalRDDStorageLevel final RDD StorageLevel
+ * @param checkpointInterval checkpoint interval
+ * @param seed seed for initialize factors
+ * @param ord order of ID
+ * @tparam ID class tag
+ * @return
+ */
+ def train[ID: ClassTag](
+ ratings: RDD[Rating[ID]],
+ rank: Int = 10,
+ numBlocks: (Int, Int) = (10, 10),
+ maxIter: Int = 10,
+ regParam: Double = 0.1,
+ intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+ finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+ checkpointInterval: Int = 10,
+ seed: Long = 0L)(
+ implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
+
+ require(!ratings.isEmpty(), s"No ratings available from $ratings")
+ require(intermediateRDDStorageLevel != StorageLevel.NONE,
+ "NMF is not designed to run without persisting intermediate RDDs.")
+
+ val sc = ratings.sparkContext
+
+ // Precompute the rating dependencies of each partition
+ val userPart = new NMFPartitioner(numBlocks._1)
+ val itemPart = new NMFPartitioner(numBlocks._2)
+ val blockRatings = partitionRatings(ratings, userPart, itemPart)
+ .persist(intermediateRDDStorageLevel)
+ val (userInBlocks, userOutBlocks) =
+ makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel)
+ userOutBlocks.count() // materialize blockRatings and user blocks
+ val swappedBlockRatings = blockRatings.map {
+ case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
+ ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
+ }
+ val (itemInBlocks, itemOutBlocks) =
+ makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel)
+ itemOutBlocks.count() // materialize item blocks
+
+ // Encoders for storing each user/item's partition ID and index within its partition using a
+ // single integer; used as an optimization
+ val userNMFLocalIndexEncoder = new NMFLocalIndexEncoder(userPart.numPartitions)
+ val itemNMFLocalIndexEncoder = new NMFLocalIndexEncoder(itemPart.numPartitions)
+
+ // These are the user and item factor matrices that, once trained, are multiplied together to
+ // estimate the rating matrix. The two matrices are stored in RDDs, partitioned by column such
+ // that each factor column resides on the same Spark worker as its corresponding user or item.
+ val seedGen = new XORShiftRandom(seed)
+ var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
+ var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
+
+ var previousCheckpointFile: Option[String] = None
+ val shouldCheckpoint: Int => Boolean = (iter) =>
+ sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0)
+ val deletePreviousCheckpointFile: () => Unit = () =>
+ previousCheckpointFile.foreach { file =>
+ try {
+ val checkpointFile = new Path(file)
+ checkpointFile.getFileSystem(sc.hadoopConfiguration).delete(checkpointFile, true)
+ } catch {
+ case e: IOException =>
+ logWarning(s"Cannot delete checkpoint file $file:", e)
+ }
+ }
+
+ val isHighDimensional = chooseSolver(ratings.sparkContext)
+
+ val computeFactors = if (isHighDimensional) {
+ logInfo("high dimensional branch")
+ (srcFactorBlocks: RDD[(Int, FactorBlock)],
+ srcOutBlocks: RDD[(Int, OutBlock)],
+ dstInBlocks: RDD[(Int, InBlock[ID])],
+ rank: Int,
+ regParam: Double,
+ srcEncoder: NMFLocalIndexEncoder) =>
+ computeFactorsACD(srcFactorBlocks, srcOutBlocks, dstInBlocks, rank, regParam, srcEncoder)
+ } else {
+ logInfo("low dimensional branch")
+ val solver = new NNLSSolver()
+ (srcFactorBlocks: RDD[(Int, FactorBlock)],
+ srcOutBlocks: RDD[(Int, OutBlock)],
+ dstInBlocks: RDD[(Int, InBlock[ID])],
+ rank: Int,
+ regParam: Double,
+ srcEncoder: NMFLocalIndexEncoder) =>
+ computeFactorsNNLS(
+ srcFactorBlocks, srcOutBlocks, dstInBlocks, rank, regParam, srcEncoder, solver)
+ }
+
+ var previousCachedItemFactors: Option[RDD[(Int, FactorBlock)]] = None
+ for (iter <- 0 until maxIter) {
+ itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks,
+ rank, regParam, userNMFLocalIndexEncoder)
+ if (shouldCheckpoint(iter)) {
+ itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
+ itemFactors.checkpoint()
+ itemFactors.count() // checkpoint item factors and cut lineage
+ itemFactors.cleanShuffleDependencies()
+ deletePreviousCheckpointFile()
+ previousCachedItemFactors.foreach(_.unpersist())
+ previousCheckpointFile = itemFactors.getCheckpointFile
+ previousCachedItemFactors = Option(itemFactors)
+ }
+ userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks,
+ rank, regParam, itemNMFLocalIndexEncoder)
+ }
+
+ val userIdAndFactors = userInBlocks
+ .mapValues(_.srcIds)
+ .join(userFactors)
+ .mapPartitions({ items =>
+ items.flatMap { case (_, (ids, factors)) =>
+ ids.iterator.zip(factors.iterator)
+ }
+ // Preserve the partitioning because IDs are consistent
+ // with the partitioners in userInBlocks
+ // and userFactors.
+ }, preservesPartitioning = true)
+ .setName("userFactors")
+ .persist(finalRDDStorageLevel)
+ val itemIdAndFactors = itemInBlocks
+ .mapValues(_.srcIds)
+ .join(itemFactors)
+ .mapPartitions({ items =>
+ items.flatMap { case (_, (ids, factors)) =>
+ ids.iterator.zip(factors.iterator)
+ }
+ }, preservesPartitioning = true)
+ .setName("itemFactors")
+ .persist(finalRDDStorageLevel)
+ if (finalRDDStorageLevel != StorageLevel.NONE) {
+ userIdAndFactors.count()
+ userInBlocks.unpersist()
+ userOutBlocks.unpersist()
+ itemOutBlocks.unpersist()
+ blockRatings.unpersist()
+ itemIdAndFactors.count()
+ itemFactors.unpersist()
+ itemInBlocks.unpersist()
+ }
+ (userIdAndFactors, itemIdAndFactors)
+ }
+
+ /**
+ * choose solver
+ * @param sc sparkContext
+ * @return choose high dimensional solver or not
+ */
+ def chooseSolver(sc: SparkContext): Boolean = {
+ val optLevelParamKey = "spark.boostkit.ml.nmf.highDimensional"
+ val optLevelStr = sc.conf.getOption(optLevelParamKey).getOrElse("true")
+ var optLevel: Boolean = true
+ try {
+ optLevel = optLevelStr.toBoolean
+ } catch {
+ case ex: Exception =>
+ throw new IllegalArgumentException(s"Parse boostkit parameter" +
+ s"($optLevelParamKey) failed, Error reason: ${ex.getMessage} cannot be casted to Boolean")
+ }
+ optLevel
+ }
+
+ /**
+ * Factor block that stores factors (Array[Float]) in an Array.
+ */
+ private type FactorBlock = Array[Array[Float]]
+
+ /**
+ * A mapping of the columns of the items factor matrix that are needed when calculating each row
+ * of the users factor matrix, and vice versa.
+ *
+ * Specifically, when calculating a user factor vector, since only those columns of the items
+ * factor matrix that correspond to the items that that user has rated are needed, we can avoid
+ * having to repeatedly copy the entire items factor matrix to each worker later in the algorithm
+ * by precomputing these dependencies for all users, storing them in an RDD of `OutBlock`s. The
+ * items' dependencies on the columns of the users factor matrix is computed similarly.
+ *
+ * =Example=
+ *
+ * Using the example provided in the `InBlock` Scaladoc, `userOutBlocks` would look like the
+ * following:
+ *
+ * {{{
+ * userOutBlocks.collect() == Seq(
+ * 0 -> Array(Array(0, 1), Array(0, 1)),
+ * 1 -> Array(Array(0), Array(0))
+ * )
+ * }}}
+ *
+ * Each value in this map-like sequence is of type `Array[Array[Int]]`. The values in the
+ * inner array are the ranks of the sorted user IDs in that partition; so in the example above,
+ * `Array(0, 1)` in partition 0 refers to user IDs 0 and 6, since when all unique user IDs in
+ * partition 0 are sorted, 0 is the first ID and 6 is the second. The position of each inner
+ * array in its enclosing outer array denotes the partition number to which item IDs map; in the
+ * example, the first `Array(0, 1)` is in position 0 of its outer array, denoting item IDs that
+ * map to partition 0.
+ *
+ * In summary, the data structure encodes the following information:
+ *
+ * * There are ratings with user IDs 0 and 6 (encoded in `Array(0, 1)`, where 0 and 1 are the
+ * indices of the user IDs 0 and 6 on partition 0) whose item IDs map to partitions 0 and 1
+ * (represented by the fact that `Array(0, 1)` appears in both the 0th and 1st positions).
+ *
+ * * There are ratings with user ID 3 (encoded in `Array(0)`, where 0 is the index of the user
+ * ID 3 on partition 1) whose item IDs map to partitions 0 and 1 (represented by the fact that
+ * `Array(0)` appears in both the 0th and 1st positions).
+ */
+ private type OutBlock = Array[Array[Int]]
+
+ /**
+ * In-link block for computing user and item factor matrices.
+ *
+ * The NMF algorithm partitions the columns of the users factor matrix evenly among Spark workers.
+ * Since each column of the factor matrix is calculated using the known ratings of the correspond-
+ * ing user, and since the ratings don't change across iterations, the NMF algorithm preshuffles
+ * the ratings to the appropriate partitions, storing them in `InBlock` objects.
+ *
+ * The ratings shuffled by item ID are computed similarly and also stored in `InBlock` objects.
+ * Note that this means every rating is stored twice, once as shuffled by user ID and once by item
+ * ID. This is a necessary tradeoff, since in general a rating will not be on the same worker
+ * when partitioned by user as by item.
+ *
+ * =Example=
+ *
+ * Say we have a small collection of eight items to offer the seven users in our application. We
+ * have some known ratings given by the users, as seen in the matrix below:
+ *
+ * {{{
+ * Items
+ * 0 1 2 3 4 5 6 7
+ * +---+---+---+---+---+---+---+---+
+ * 0 | |0.1| | |0.4| | |0.7|
+ * +---+---+---+---+---+---+---+---+
+ * 1 | | | | | | | | |
+ * +---+---+---+---+---+---+---+---+
+ * U 2 | | | | | | | | |
+ * s +---+---+---+---+---+---+---+---+
+ * e 3 | |3.1| | |3.4| | |3.7|
+ * r +---+---+---+---+---+---+---+---+
+ * s 4 | | | | | | | | |
+ * +---+---+---+---+---+---+---+---+
+ * 5 | | | | | | | | |
+ * +---+---+---+---+---+---+---+---+
+ * 6 | |6.1| | |6.4| | |6.7|
+ * +---+---+---+---+---+---+---+---+
+ * }}}
+ *
+ * The ratings are represented as an RDD, passed to the `partitionRatings` method as the `ratings`
+ * parameter:
+ *
+ * {{{
+ * ratings.collect() == Seq(
+ * Rating(0, 1, 0.1f),
+ * Rating(0, 4, 0.4f),
+ * Rating(0, 7, 0.7f),
+ * Rating(3, 1, 3.1f),
+ * Rating(3, 4, 3.4f),
+ * Rating(3, 7, 3.7f),
+ * Rating(6, 1, 6.1f),
+ * Rating(6, 4, 6.4f),
+ * Rating(6, 7, 6.7f)
+ * )
+ * }}}
+ *
+ * Say that we are using two partitions to calculate each factor matrix:
+ *
+ * {{{
+ * val userPart = new NMFPartitioner(2)
+ * val itemPart = new NMFPartitioner(2)
+ * val blockRatings = partitionRatings(ratings, userPart, itemPart)
+ * }}}
+ *
+ * Ratings are mapped to partitions using the user/item IDs modulo the number of partitions. With
+ * two partitions, ratings with even-valued user IDs are shuffled to partition 0 while those with
+ * odd-valued user IDs are shuffled to partition 1:
+ *
+ * {{{
+ * userInBlocks.collect() == Seq(
+ * 0 -> Seq(
+ * // Internally, the class stores the ratings in a more optimized format than
+ * // a sequence of `Rating`s, but for clarity we show it as such here.
+ * Rating(0, 1, 0.1f),
+ * Rating(0, 4, 0.4f),
+ * Rating(0, 7, 0.7f),
+ * Rating(6, 1, 6.1f),
+ * Rating(6, 4, 6.4f),
+ * Rating(6, 7, 6.7f)
+ * ),
+ * 1 -> Seq(
+ * Rating(3, 1, 3.1f),
+ * Rating(3, 4, 3.4f),
+ * Rating(3, 7, 3.7f)
+ * )
+ * )
+ * }}}
+ *
+ * Similarly, ratings with even-valued item IDs are shuffled to partition 0 while those with
+ * odd-valued item IDs are shuffled to partition 1:
+ *
+ * {{{
+ * itemInBlocks.collect() == Seq(
+ * 0 -> Seq(
+ * Rating(0, 4, 0.4f),
+ * Rating(3, 4, 3.4f),
+ * Rating(6, 4, 6.4f)
+ * ),
+ * 1 -> Seq(
+ * Rating(0, 1, 0.1f),
+ * Rating(0, 7, 0.7f),
+ * Rating(3, 1, 3.1f),
+ * Rating(3, 7, 3.7f),
+ * Rating(6, 1, 6.1f),
+ * Rating(6, 7, 6.7f)
+ * )
+ * )
+ * }}}
+ *
+ * @param srcIds src ids (ordered)
+ * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and
+ * ratings are associated with srcIds(i).
+ * @param dstEncodedIndices encoded dst indices
+ * @param ratings ratings
+ * @see [[NMFLocalIndexEncoder]]
+ */
+ private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag](
+ srcIds: Array[ID],
+ dstPtrs: Array[Int],
+ dstEncodedIndices: Array[Int],
+ ratings: Array[Float]) {
+ /** Size of the block. */
+ def size: Int = ratings.length
+ require(dstEncodedIndices.length == size)
+ require(dstPtrs.length == srcIds.length + 1)
+ }
+
+ /**
+ * Initializes factors randomly given the in-link blocks.
+ *
+ * @param inBlocks in-link blocks
+ * @param rank rank
+ * @return initialized factor blocks
+ */
+ private def initialize[ID](
+ inBlocks: RDD[(Int, InBlock[ID])],
+ rank: Int,
+ seed: Long): RDD[(Int, FactorBlock)] = {
+ // Choose a unit vector uniformly at random from the unit sphere, but from the
+ // "first quadrant" where all elements are nonnegative. This can be done by choosing
+ // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
+ // This appears to create factorizations that have a slightly better reconstruction
+ // (<1%) compared picking elements uniformly at random in [0,1].
+ inBlocks.mapPartitions({ iter =>
+ iter.map {
+ case (srcBlockId, inBlock) =>
+ val random = new XORShiftRandom(byteswap64(seed ^ srcBlockId))
+ val factors = Array.fill(inBlock.srcIds.length) {
+ val factor = Array.fill(rank)(random.nextGaussian().toFloat)
+ val nrm = BLAS.nativeBLAS.snrm2(rank, factor, 1)
+ BLAS.nativeBLAS.sscal(rank, 1.0f / nrm, factor, 1)
+ factor.map(x => {math.max(0, x)})
+ }
+ (srcBlockId, factors)
+ }
+ }, preservesPartitioning = true)
+ }
+
+ /**
+ * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
+ */
+ private[recommendation] case class RatingBlock[@specialized(Int, Long) ID: ClassTag](
+ srcIds: Array[ID],
+ dstIds: Array[ID],
+ ratings: Array[Float]) {
+ /** Size of the block. */
+ def size: Int = srcIds.length
+ require(dstIds.length == srcIds.length)
+ require(ratings.length == srcIds.length)
+ }
+
+ /**
+ * Builder for [[RatingBlock]]. `mutable.ArrayBuilder` is used to avoid boxing/unboxing.
+ */
+ private[recommendation] class RatingBlockBuilder[@specialized(Int, Long) ID: ClassTag]
+ extends Serializable {
+
+ private val srcIds = mutable.ArrayBuilder.make[ID]
+ private val dstIds = mutable.ArrayBuilder.make[ID]
+ private val ratings = mutable.ArrayBuilder.make[Float]
+ var size = 0
+
+ /** Adds a rating. */
+ def add(r: Rating[ID]): this.type = {
+ size += 1
+ srcIds += r.user
+ dstIds += r.item
+ ratings += r.rating
+ this
+ }
+
+ /** Merges another [[RatingBlockBuilder]]. */
+ def merge(other: RatingBlock[ID]): this.type = {
+ size += other.srcIds.length
+ srcIds ++= other.srcIds
+ dstIds ++= other.dstIds
+ ratings ++= other.ratings
+ this
+ }
+
+ /** Builds a [[RatingBlock]]. */
+ def build(): RatingBlock[ID] = {
+ RatingBlock[ID](srcIds.result(), dstIds.result(), ratings.result())
+ }
+ }
+
+ /**
+ * Groups an RDD of [[Rating]]s by the user partition and item partition to which each `Rating`
+ * maps according to the given partitioners. The returned pair RDD holds the ratings, encoded in
+ * a memory-efficient format but otherwise unchanged, keyed by the (user partition ID, item
+ * partition ID) pair.
+ *
+ * Performance note: This is an expensive operation that performs an RDD shuffle.
+ *
+ * Implementation note: This implementation produces the same result as the following but
+ * generates fewer intermediate objects:
+ *
+ * {{{
+ * ratings.map { r =>
+ * ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r)
+ * }.aggregateByKey(new RatingBlockBuilder)(
+ * seqOp = (b, r) => b.add(r),
+ * combOp = (b0, b1) => b0.merge(b1.build()))
+ * .mapValues(_.build())
+ * }}}
+ *
+ * @param ratings raw ratings
+ * @param srcPart partitioner for src IDs
+ * @param dstPart partitioner for dst IDs
+ * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
+ */
+ private def partitionRatings[ID: ClassTag](
+ ratings: RDD[Rating[ID]],
+ srcPart: Partitioner,
+ dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = {
+ val numPartitions = srcPart.numPartitions * dstPart.numPartitions
+ ratings.mapPartitions { iter =>
+ val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID])
+ iter.flatMap { r =>
+ val srcBlockId = srcPart.getPartition(r.user)
+ val dstBlockId = dstPart.getPartition(r.item)
+ val idx = srcBlockId + srcPart.numPartitions * dstBlockId
+ val builder = builders(idx)
+ builder.add(r)
+ if (builder.size >= 2048) { // 2048 * (3 * 4) = 24k
+ builders(idx) = new RatingBlockBuilder
+ Iterator.single(((srcBlockId, dstBlockId), builder.build()))
+ } else {
+ Iterator.empty
+ }
+ } ++ {
+ builders.iterator.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) =>
+ val srcBlockId = idx % srcPart.numPartitions
+ val dstBlockId = idx / srcPart.numPartitions
+ ((srcBlockId, dstBlockId), block.build())
+ }
+ }
+ }.groupByKey().mapValues { blocks =>
+ val builder = new RatingBlockBuilder[ID]
+ blocks.foreach(builder.merge)
+ builder.build()
+ }.setName("ratingBlocks")
+ }
+
+ /**
+ * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
+ *
+ * @param encoder encoder for dst indices
+ */
+ private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag](
+ encoder: NMFLocalIndexEncoder)(
+ implicit ord: Ordering[ID]) {
+
+ private val srcIds = mutable.ArrayBuilder.make[ID]
+ private val dstEncodedIndices = mutable.ArrayBuilder.make[Int]
+ private val ratings = mutable.ArrayBuilder.make[Float]
+
+ /**
+ * Adds a dst block of (srcId, dstLocalIndex, rating) tuples.
+ *
+ * @param dstBlockId dst block ID
+ * @param srcIds original src IDs
+ * @param dstLocalIndices dst local indices
+ * @param ratings ratings
+ */
+ def add(
+ dstBlockId: Int,
+ srcIds: Array[ID],
+ dstLocalIndices: Array[Int],
+ ratings: Array[Float]): this.type = {
+ val sz = srcIds.length
+ require(dstLocalIndices.length == sz)
+ require(ratings.length == sz)
+ this.srcIds ++= srcIds
+ this.ratings ++= ratings
+ var j = 0
+ while (j < sz) {
+ this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j))
+ j += 1
+ }
+ this
+ }
+
+ /** Builds a [[UncompressedInBlock]]. */
+ def build(): UncompressedInBlock[ID] = {
+ new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result())
+ }
+ }
+
+ /**
+ * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays.
+ */
+ private[recommendation] class UncompressedInBlock[@specialized(Int, Long) ID: ClassTag](
+ val srcIds: Array[ID],
+ val dstEncodedIndices: Array[Int],
+ val ratings: Array[Float])(
+ implicit ord: Ordering[ID]) {
+
+ /** Size the of block. */
+ def length: Int = srcIds.length
+
+ /**
+ * Compresses the block into an `InBlock`. The algorithm is the same as converting a sparse
+ * matrix from coordinate list (COO) format into compressed sparse column (CSC) format.
+ * Sorting is done using Spark's built-in Timsort to avoid generating too many objects.
+ */
+ def compress(): InBlock[ID] = {
+ val sz = length
+ assert(sz > 0, "Empty in-link block should not exist.")
+ sort()
+ val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[ID]
+ val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
+ var preSrcId = srcIds(0)
+ uniqueSrcIdsBuilder += preSrcId
+ var curCount = 1
+ var i = 1
+ while (i < sz) {
+ val srcId = srcIds(i)
+ if (srcId != preSrcId) {
+ uniqueSrcIdsBuilder += srcId
+ dstCountsBuilder += curCount
+ preSrcId = srcId
+ curCount = 0
+ }
+ curCount += 1
+ i += 1
+ }
+ dstCountsBuilder += curCount
+ val uniqueSrcIds = uniqueSrcIdsBuilder.result()
+ val numUniqueSrdIds = uniqueSrcIds.length
+ val dstCounts = dstCountsBuilder.result()
+ val dstPtrs = new Array[Int](numUniqueSrdIds + 1)
+ var sum = 0
+ i = 0
+ while (i < numUniqueSrdIds) {
+ sum += dstCounts(i)
+ i += 1
+ dstPtrs(i) = sum
+ }
+ InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings)
+ }
+
+ private def sort(): Unit = {
+ val sz = length
+ // Since there might be interleaved log messages, we insert a unique id for easy pairing.
+ val sortId = Utils.random.nextInt()
+ logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
+ val start = System.nanoTime()
+ val sorter = new Sorter(new UncompressedInBlockSort[ID])
+ sorter.sort(this, 0, length, Ordering[KeyWrapper[ID]])
+ val duration = (System.nanoTime() - start) / 1e9
+ logDebug(s"Sorting took $duration seconds. (sortId = $sortId)")
+ }
+ }
+
+ /**
+ * A wrapper that holds a primitive key.
+ *
+ * @see [[UncompressedInBlockSort]]
+ */
+ private class KeyWrapper[@specialized(Int, Long) ID: ClassTag](
+ implicit ord: Ordering[ID]) extends Ordered[KeyWrapper[ID]] {
+
+ var key: ID = _
+
+ override def compare(that: KeyWrapper[ID]): Int = {
+ ord.compare(key, that.key)
+ }
+
+ def setKey(key: ID): this.type = {
+ this.key = key
+ this
+ }
+ }
+
+ /**
+ * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]].
+ */
+ private class UncompressedInBlockSort[@specialized(Int, Long) ID: ClassTag](
+ implicit ord: Ordering[ID])
+ extends SortDataFormat[KeyWrapper[ID], UncompressedInBlock[ID]] {
+
+ override def newKey(): KeyWrapper[ID] = new KeyWrapper()
+
+ override def getKey(
+ data: UncompressedInBlock[ID],
+ pos: Int,
+ reuse: KeyWrapper[ID]): KeyWrapper[ID] = {
+ if (reuse == null) {
+ new KeyWrapper().setKey(data.srcIds(pos))
+ } else {
+ reuse.setKey(data.srcIds(pos))
+ }
+ }
+
+ override def getKey(
+ data: UncompressedInBlock[ID],
+ pos: Int): KeyWrapper[ID] = {
+ getKey(data, pos, null)
+ }
+
+ private def swapElements[@specialized(Int, Float) T](
+ data: Array[T],
+ pos0: Int,
+ pos1: Int): Unit = {
+ val tmp = data(pos0)
+ data(pos0) = data(pos1)
+ data(pos1) = tmp
+ }
+
+ override def swap(data: UncompressedInBlock[ID], pos0: Int, pos1: Int): Unit = {
+ swapElements(data.srcIds, pos0, pos1)
+ swapElements(data.dstEncodedIndices, pos0, pos1)
+ swapElements(data.ratings, pos0, pos1)
+ }
+
+ override def copyRange(
+ src: UncompressedInBlock[ID],
+ srcPos: Int,
+ dst: UncompressedInBlock[ID],
+ dstPos: Int,
+ length: Int): Unit = {
+ System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length)
+ System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length)
+ System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length)
+ }
+
+ override def allocate(length: Int): UncompressedInBlock[ID] = {
+ new UncompressedInBlock(
+ new Array[ID](length), new Array[Int](length), new Array[Float](length))
+ }
+
+ override def copyElement(
+ src: UncompressedInBlock[ID],
+ srcPos: Int,
+ dst: UncompressedInBlock[ID],
+ dstPos: Int): Unit = {
+ dst.srcIds(dstPos) = src.srcIds(srcPos)
+ dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos)
+ dst.ratings(dstPos) = src.ratings(srcPos)
+ }
+ }
+
+ /**
+ * Creates in-blocks and out-blocks from rating blocks.
+ *
+ * @param prefix prefix for in/out-block names
+ * @param ratingBlocks rating blocks
+ * @param srcPart partitioner for src IDs
+ * @param dstPart partitioner for dst IDs
+ * @return (in-blocks, out-blocks)
+ */
+ private def makeBlocks[ID: ClassTag](
+ prefix: String,
+ ratingBlocks: RDD[((Int, Int), RatingBlock[ID])],
+ srcPart: Partitioner,
+ dstPart: Partitioner,
+ storageLevel: StorageLevel)(
+ implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = {
+ val inBlocks = ratingBlocks.map {
+ case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
+ // The implementation is a faster version of
+ // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap
+ val start = System.nanoTime()
+ val dstIdSet = new OpenHashSet[ID](1 << 20)
+ dstIds.foreach(dstIdSet.add)
+ val sortedDstIds = new Array[ID](dstIdSet.size)
+ var i = 0
+ var pos = dstIdSet.nextPos(0)
+ while (pos != -1) {
+ sortedDstIds(i) = dstIdSet.getValue(pos)
+ pos = dstIdSet.nextPos(pos + 1)
+ i += 1
+ }
+ assert(i == dstIdSet.size)
+ Sorting.quickSort(sortedDstIds)
+ val dstIdToLocalIndex = new OpenHashMap[ID, Int](sortedDstIds.length)
+ i = 0
+ while (i < sortedDstIds.length) {
+ dstIdToLocalIndex.update(sortedDstIds(i), i)
+ i += 1
+ }
+ logDebug(
+ "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.")
+ val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply)
+ (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
+ }.groupByKey(new NMFPartitioner(srcPart.numPartitions))
+ .mapValues { iter =>
+ val builder =
+ new UncompressedInBlockBuilder[ID](new NMFLocalIndexEncoder(dstPart.numPartitions))
+ iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
+ builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
+ }
+ builder.build().compress()
+ }.setName(prefix + "InBlocks")
+ .persist(storageLevel)
+ val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
+ val encoder = new NMFLocalIndexEncoder(dstPart.numPartitions)
+ val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
+ var i = 0
+ val seen = new Array[Boolean](dstPart.numPartitions)
+ while (i < srcIds.length) {
+ var j = dstPtrs(i)
+ ju.Arrays.fill(seen, false)
+ while (j < dstPtrs(i + 1)) {
+ val dstBlockId = encoder.blockId(dstEncodedIndices(j))
+ if (!seen(dstBlockId)) {
+ activeIds(dstBlockId) += i // add the local index in this out-block
+ seen(dstBlockId) = true
+ }
+ j += 1
+ }
+ i += 1
+ }
+ activeIds.map { x =>
+ x.result()
+ }
+ }.setName(prefix + "OutBlocks")
+ .persist(storageLevel)
+ (inBlocks, outBlocks)
+ }
+
+ /**
+ * Compute dst factors by constructing and solving least square problems.
+ *
+ * @param srcFactorBlocks src factors
+ * @param srcOutBlocks src out-blocks
+ * @param dstInBlocks dst in-blocks
+ * @param rank rank
+ * @param regParam regularization constant
+ * @param srcEncoder encoder for src local indices
+ * @param solver solver for least squares problems
+ * @return dst factors
+ */
+ private def computeFactorsNNLS[ID](
+ srcFactorBlocks: RDD[(Int, FactorBlock)],
+ srcOutBlocks: RDD[(Int, OutBlock)],
+ dstInBlocks: RDD[(Int, InBlock[ID])],
+ rank: Int,
+ regParam: Double,
+ srcEncoder: NMFLocalIndexEncoder,
+ solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = {
+ val numSrcBlocks = srcFactorBlocks.partitions.length
+ val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
+ case (srcBlockId, (srcOutBlock, srcFactors)) =>
+ srcOutBlock.iterator.zipWithIndex.map { case (activeIndices, dstBlockId) =>
+ (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
+ }
+ }
+ val merged = srcOut.groupByKey(new NMFPartitioner(dstInBlocks.partitions.length))
+
+ // SPARK-28927: Nondeterministic RDDs causes inconsistent in/out blocks in case of rerun.
+ // It can cause runtime error when matching in/out user/item blocks.
+ val isBlockRDDNondeterministic =
+ dstInBlocks.outputDeterministicLevel == DeterministicLevel.INDETERMINATE ||
+ srcOutBlocks.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
+
+ dstInBlocks.join(merged).mapValues {
+ case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
+ val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
+ srcFactors.foreach { case (srcBlockId, factors) =>
+ sortedSrcFactors(srcBlockId) = factors
+ }
+ val dstFactors = new Array[Array[Float]](dstIds.length)
+ var j = 0
+ val ls = new NormalEquation(rank)
+ while (j < dstIds.length) {
+ ls.reset()
+ var i = srcPtrs(j)
+ var numExplicits = 0
+ while (i < srcPtrs(j + 1)) {
+ val encoded = srcEncodedIndices(i)
+ val blockId = srcEncoder.blockId(encoded)
+ val localIndex = srcEncoder.localIndex(encoded)
+ var srcFactor: Array[Float] = null
+ try {
+ srcFactor = sortedSrcFactors(blockId)(localIndex)
+ } catch {
+ case a: ArrayIndexOutOfBoundsException if isBlockRDDNondeterministic =>
+ val errMsg = "A failure detected when matching In/Out blocks of users/items. " +
+ "Because at least one In/Out block RDD is found to be nondeterministic now, " +
+ "the issue is probably caused by nondeterministic input data. You can try to " +
+ "checkpoint training data to make it deterministic. If you do `repartition` + " +
+ "`sample` or `randomSplit`, you can also try to sort it before `sample` or " +
+ "`randomSplit` to make it deterministic."
+ throw new SparkException(errMsg, a)
+ }
+ val rating = ratings(i)
+ ls.add(srcFactor, rating)
+ numExplicits += 1
+ i += 1
+ }
+ // Weight lambda by the number of explicit ratings based on the ALS-WR paper.
+ dstFactors(j) = solver.solve(ls, numExplicits * regParam)
+ j += 1
+ }
+ dstFactors
+ }
+ }
+
+ /**
+ * update factors by NMFSolver
+ * @param srcFactorBlocks src factors
+ * @param srcOutBlocks src out-blocks
+ * @param dstInBlocks dst in-blocks
+ * @param rank rank
+ * @param regParam regularization constant
+ * @param srcEncoder encoder for src local indices
+ * @return updated factors
+ */
+ private def computeFactorsACD[ID](
+ srcFactorBlocks: RDD[(Int, FactorBlock)],
+ srcOutBlocks: RDD[(Int, OutBlock)],
+ dstInBlocks: RDD[(Int, InBlock[ID])],
+ rank: Int,
+ regParam: Double,
+ srcEncoder: NMFLocalIndexEncoder): RDD[(Int, FactorBlock)] = {
+ val numSrcBlocks = srcFactorBlocks.partitions.length
+ val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
+ case (srcBlockId, (srcOutBlock, srcFactors)) =>
+ srcOutBlock.iterator.zipWithIndex.map { case (activeIndices, dstBlockId) =>
+ (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
+ }
+ }
+ val merged = srcOut.groupByKey(new NMFPartitioner(dstInBlocks.partitions.length))
+
+ dstInBlocks.join(merged).mapValues {
+ case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
+ NMFSolver.updateFactors(
+ dstIds,
+ srcPtrs,
+ srcEncodedIndices,
+ ratings,
+ srcFactors,
+ numSrcBlocks,
+ rank,
+ regParam.toFloat,
+ srcEncoder)
+ }
+ }
+
+ /**
+ * Partitioner used by NMF. We require that getPartition is a projection. That is, for any key k,
+ * we have getPartition(getPartition(k)) = getPartition(k). Since the default HashPartitioner
+ * satisfies this requirement, we simply use a type alias here.
+ */
+ private[recommendation] type NMFPartitioner = org.apache.spark.HashPartitioner
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 825bb857f2877337d116b527790d67b4cd7b396a..56e0128fb04d81febd58d9d1124f34192cf85f97 100644
--- a/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -294,7 +294,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
- sparkSession.createDataFrame(nodeData).write.parquet(dataPath)
+ val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
+ sparkSession.createDataFrame(nodeData).repartition(numDataParts).write.parquet(dataPath)
}
}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
new file mode 100644
index 0000000000000000000000000000000000000000..36818ab66b40871a3f0a580469dfa3765a6a4a98
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -0,0 +1,809 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import scala.util.Random
+
+import breeze.linalg.{axpy => brzAxpy, norm => brzNorm, Vector => BV}
+import breeze.numerics.{sqrt => brzSqrt}
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.{PredictorParams, StaticUtils}
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.linalg.BLAS._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.regression.FactorizationMachines._
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.mllib.{linalg => OldLinalg}
+import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
+import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.optimization.{Gradient, LBFGSN, SquaredL2Updater, Updater}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for Factorization Machines
+ */
+private[ml] trait FactorizationMachinesParams extends PredictorParams
+ with HasMaxIter with HasStepSize with HasTol with HasSolver with HasSeed
+ with HasFitIntercept with HasRegParam with HasWeightCol {
+
+ /**
+ * Param for dimensionality of the factors (>= 0)
+ * @group param
+ */
+ @Since("3.0.0")
+ final val factorSize: IntParam = new IntParam(this, "factorSize",
+ "Dimensionality of the factor vectors, " +
+ "which are used to get pairwise interactions between variables",
+ ParamValidators.gt(0))
+
+ /** @group getParam */
+ @Since("3.0.0")
+ final def getFactorSize: Int = $(factorSize)
+
+ /**
+ * Param for whether to fit linear term (aka 1-way term)
+ * @group param
+ */
+ @Since("3.0.0")
+ final val fitLinear: BooleanParam = new BooleanParam(this, "fitLinear",
+ "whether to fit linear term (aka 1-way term)")
+
+ /** @group getParam */
+ @Since("3.0.0")
+ final def getFitLinear: Boolean = $(fitLinear)
+
+ /**
+ * Param for mini-batch fraction, must be in range (0, 1]
+ * @group param
+ */
+ @Since("3.0.0")
+ final val miniBatchFraction: DoubleParam = new DoubleParam(this, "miniBatchFraction",
+ "fraction of the input data set that should be used for one iteration of gradient descent",
+ ParamValidators.inRange(0, 1, false, true))
+
+ /** @group getParam */
+ @Since("3.0.0")
+ final def getMiniBatchFraction: Double = $(miniBatchFraction)
+
+ /**
+ * Param for standard deviation of initial coefficients
+ * @group param
+ */
+ @Since("3.0.0")
+ final val initStd: DoubleParam = new DoubleParam(this, "initStd",
+ "standard deviation of initial coefficients", ParamValidators.gt(0))
+
+ /** @group getParam */
+ @Since("3.0.0")
+ final def getInitStd: Double = $(initStd)
+
+ /**
+ * The solver algorithm for optimization.
+ * Supported options: "gd", "adamW".
+ * Default: "adamW"
+ *
+ * @group param
+ */
+ @Since("3.0.0")
+ final override val solver: Param[String] = new Param[String](this, "solver",
+ "The solver algorithm for optimization. Supported options: " +
+ s"${supportedSolvers.mkString(", ")}. (Default adamW)",
+ ParamValidators.inArray[String](supportedSolvers))
+
+ setDefault(factorSize -> 8, fitIntercept -> true, fitLinear -> true, regParam -> 0.0,
+ miniBatchFraction -> 1.0, initStd -> 0.01, maxIter -> 100, stepSize -> 1.0, tol -> 1E-6,
+ solver -> AdamW)
+}
+
+private[ml] trait FactorizationMachines extends FactorizationMachinesParams {
+
+ private[ml] def initCoefficients(numFeatures: Int): OldVector = {
+ val rnd = new Random($(seed))
+ val initialCoefficients =
+ OldVectors.dense(
+ Array.fill($(factorSize) * numFeatures)(rnd.nextGaussian() * $(initStd)) ++
+ (if ($(fitLinear)) new Array[Double](numFeatures) else Array.emptyDoubleArray) ++
+ (if ($(fitIntercept)) new Array[Double](1) else Array.emptyDoubleArray))
+ initialCoefficients
+ }
+
+ private[ml] def trainImpl(
+ data: RDD[(Double, OldVector)],
+ numFeatures: Int,
+ loss: String
+ ): (Vector, Array[Double]) = {
+
+ // initialize coefficients
+ val initialCoefficients = initCoefficients(numFeatures)
+ val coefficientsSize = initialCoefficients.size
+
+ // optimize coefficients with gradient descent
+ val gradient = parseLoss(loss, $(factorSize), $(fitIntercept), $(fitLinear), numFeatures)
+
+ val updater = parseSolver($(solver), coefficientsSize)
+
+ val optimizer = new LBFGSN(gradient, updater)
+ .setNumIterations($(maxIter))
+ .setRegParam($(regParam))
+ .setConvergenceTol($(tol))
+ val (coefficients, lossHistory) = optimizer.optimizeWithLossReturned(data, initialCoefficients)
+ (coefficients.asML, lossHistory)
+ }
+}
+
+private[ml] object FactorizationMachines {
+
+ /** String name for "gd". */
+ val GD = "gd"
+
+ /** String name for "adamW". */
+ val AdamW = "adamW"
+
+ /** Set of solvers that FactorizationMachines supports. */
+ val supportedSolvers = Array(GD, AdamW)
+
+ /** String name for "logisticLoss". */
+ val LogisticLoss = "logisticLoss"
+
+ /** String name for "squaredError". */
+ val SquaredError = "squaredError"
+
+ /** Set of loss function names that FactorizationMachines supports. */
+ val supportedRegressorLosses = Array(SquaredError)
+ val supportedClassifierLosses = Array(LogisticLoss)
+ val supportedLosses = supportedRegressorLosses ++ supportedClassifierLosses
+
+ def parseSolver(solver: String, coefficientsSize: Int): Updater = {
+ solver match {
+ case GD => new SquaredL2Updater()
+ case AdamW => new AdamWUpdater(coefficientsSize)
+ }
+ }
+
+ def parseLoss(
+ lossFunc: String,
+ factorSize: Int,
+ fitIntercept: Boolean,
+ fitLinear: Boolean,
+ numFeatures: Int
+ ): BaseFactorizationMachinesGradient = {
+
+ lossFunc match {
+ case LogisticLoss =>
+ new LogisticFactorizationMachinesGradient(factorSize, fitIntercept, fitLinear, numFeatures)
+ case SquaredError =>
+ new MSEFactorizationMachinesGradient(factorSize, fitIntercept, fitLinear, numFeatures)
+ case _ => throw new IllegalArgumentException(s"loss function type $lossFunc is invalidation")
+ }
+ }
+
+ def splitCoefficients(
+ coefficients: Vector,
+ numFeatures: Int,
+ factorSize: Int,
+ fitIntercept: Boolean,
+ fitLinear: Boolean
+ ): (Double, Vector, Matrix) = {
+
+ val coefficientsSize = numFeatures * factorSize +
+ (if (fitLinear) numFeatures else 0) + (if (fitIntercept) 1 else 0)
+ require(coefficientsSize == coefficients.size,
+ s"coefficients.size did not match the excepted size ${coefficientsSize}")
+
+ val intercept = if (fitIntercept) coefficients(coefficients.size - 1) else 0.0
+ val linear: Vector = if (fitLinear) {
+ new DenseVector(coefficients.toArray.slice(
+ numFeatures * factorSize, numFeatures * factorSize + numFeatures))
+ } else {
+ Vectors.sparse(numFeatures, Seq.empty)
+ }
+ val factors = new DenseMatrix(numFeatures, factorSize,
+ coefficients.toArray.slice(0, numFeatures * factorSize), true)
+ (intercept, linear, factors)
+ }
+
+ def combineCoefficients(
+ intercept: Double,
+ linear: Vector,
+ factors: Matrix,
+ fitIntercept: Boolean,
+ fitLinear: Boolean
+ ): Vector = {
+
+ val coefficients = factors.toDense.values ++
+ (if (fitLinear) linear.toArray else Array.emptyDoubleArray) ++
+ (if (fitIntercept) Array(intercept) else Array.emptyDoubleArray)
+ new DenseVector(coefficients)
+ }
+
+ def getRawPrediction(
+ features: Vector,
+ intercept: Double,
+ linear: Vector,
+ factors: Matrix
+ ): Double = {
+ var rawPrediction = intercept + features.dot(linear)
+ (0 until factors.numCols).foreach { f =>
+ var sumSquare = 0.0
+ var sum = 0.0
+ features.foreachNonZero { case (index, value) =>
+ val vx = factors(index, f) * value
+ sumSquare += vx * vx
+ sum += vx
+ }
+ rawPrediction += 0.5 * (sum * sum - sumSquare)
+ }
+
+ rawPrediction
+ }
+}
+
+/**
+ * Params for FMRegressor
+ */
+private[regression] trait FMRegressorParams extends FactorizationMachinesParams {
+}
+
+/**
+ * Factorization Machines learning algorithm for regression.
+ * It supports normal gradient descent and AdamW solver.
+ *
+ * The implementation is based upon:
+ *
+ * S. Rendle. "Factorization machines" 2010.
+ *
+ * FM is able to estimate interactions even in problems with huge sparsity
+ * (like advertising and recommendation system).
+ * FM formula is:
+ *
+ * $$
+ * \begin{align}
+ * y = w_0 + \sum\limits^n_{i-1} w_i x_i +
+ * \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j
+ * \end{align}
+ * $$
+ *
+ * First two terms denote global bias and linear term (as same as linear regression),
+ * and last term denotes pairwise interactions term. v_i describes the i-th variable
+ * with k factors.
+ *
+ * FM regression model uses MSE loss which can be solved by gradient descent method, and
+ * regularization terms like L2 are usually added to the loss function to prevent overfitting.
+ */
+@Since("3.0.0")
+class FMRegressor @Since("3.0.0") (
+ @Since("3.0.0") override val uid: String)
+ extends Regressor[Vector, FMRegressor, FMRegressionModel]
+ with FactorizationMachines with FMRegressorParams with DefaultParamsWritable with Logging {
+
+ @Since("3.0.0")
+ def this() = this(Identifiable.randomUID("fmr"))
+
+ /**
+ * Set the dimensionality of the factors.
+ * Default is 8.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setFactorSize(value: Int): this.type = set(factorSize, value)
+
+ /**
+ * Set whether to fit intercept term.
+ * Default is true.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+
+ /**
+ * Set whether to fit linear term.
+ * Default is true.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setFitLinear(value: Boolean): this.type = set(fitLinear, value)
+
+ /**
+ * Set the L2 regularization parameter.
+ * Default is 0.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /**
+ * Set the mini-batch fraction parameter.
+ * Default is 1.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setMiniBatchFraction(value: Double): this.type = set(miniBatchFraction, value)
+
+ /**
+ * Set the standard deviation of initial coefficients.
+ * Default is 0.01.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setInitStd(value: Double): this.type = set(initStd, value)
+
+ /**
+ * Set the maximum number of iterations.
+ * Default is 100.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /**
+ * Set the initial step size for the first step (like learning rate).
+ * Default is 1.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setStepSize(value: Double): this.type = set(stepSize, value)
+
+ /**
+ * Set the convergence tolerance of iterations.
+ * Default is 1E-6.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setTol(value: Double): this.type = set(tol, value)
+
+ /**
+ * Set the solver algorithm used for optimization.
+ * Supported options: "gd", "adamW".
+ * Default: "adamW"
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setSolver(value: String): this.type = set(solver, value)
+
+ /**
+ * Set the random seed for weight initialization.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ override protected def train(
+ dataset: Dataset[_]
+ ): FMRegressionModel = instrumented { instr =>
+
+ instr.logPipelineStage(this)
+ instr.logDataset(dataset)
+ instr.logParams(this, factorSize, fitIntercept, fitLinear, regParam,
+ miniBatchFraction, initStd, maxIter, stepSize, tol, solver)
+
+ val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+ instr.logNumFeatures(numFeatures)
+
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+ val labeledPoint = extractLabeledPoints(dataset)
+ val data: RDD[(Double, OldVector)] =
+ labeledPoint.map(x => (x.label + StaticUtils.ZERO_DOUBLE, x.features))
+
+ if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val (coefficients, _) = trainImpl(data, numFeatures, SquaredError)
+
+ val (intercept, linear, factors) = splitCoefficients(
+ coefficients, numFeatures, $(factorSize), $(fitIntercept), $(fitLinear))
+
+ if (handlePersistence) data.unpersist()
+
+ copyValues(new FMRegressionModel(uid, intercept, linear, factors))
+ }
+
+ @Since("3.0.0")
+ override def copy(extra: ParamMap): FMRegressor = defaultCopy(extra)
+}
+
+@Since("3.0.0")
+object FMRegressor extends DefaultParamsReadable[FMRegressor] {
+
+ @Since("3.0.0")
+ override def load(path: String): FMRegressor = super.load(path)
+}
+
+/**
+ * Model produced by [[FMRegressor]].
+ */
+@Since("3.0.0")
+class FMRegressionModel private[regression] (
+ @Since("3.0.0") override val uid: String,
+ @Since("3.0.0") val intercept: Double,
+ @Since("3.0.0") val linear: Vector,
+ @Since("3.0.0") val factors: Matrix)
+ extends RegressionModel[Vector, FMRegressionModel]
+ with FMRegressorParams with MLWritable {
+
+ @Since("3.0.0")
+ override val numFeatures: Int = linear.size
+
+ @Since("3.0.0")
+ override def predict(features: Vector): Double = {
+ getRawPrediction(features, intercept, linear, factors)
+ }
+
+ @Since("3.0.0")
+ override def copy(extra: ParamMap): FMRegressionModel = {
+ copyValues(new FMRegressionModel(uid, intercept, linear, factors), extra)
+ }
+
+ @Since("3.0.0")
+ override def write: MLWriter =
+ new FMRegressionModel.FMRegressionModelWriter(this)
+
+ override def toString: String = {
+ s"FMRegressionModel: " +
+ s"uid=${super.toString}, numFeatures=$numFeatures, " +
+ s"factorSize=${$(factorSize)}, fitLinear=${$(fitLinear)}, fitIntercept=${$(fitIntercept)}"
+ }
+}
+
+@Since("3.0.0")
+object FMRegressionModel extends MLReadable[FMRegressionModel] {
+
+ @Since("3.0.0")
+ override def read: MLReader[FMRegressionModel] = new FMRegressionModelReader
+
+ @Since("3.0.0")
+ override def load(path: String): FMRegressionModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[FMRegressionModel]] */
+ private[FMRegressionModel] class FMRegressionModelWriter(
+ instance: FMRegressionModel) extends MLWriter with Logging {
+
+ private case class Data(
+ intercept: Double,
+ linear: Vector,
+ factors: Matrix)
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.intercept, instance.linear, instance.factors)
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class FMRegressionModelReader extends MLReader[FMRegressionModel] {
+
+ private val className = classOf[FMRegressionModel].getName
+
+ override def load(path: String): FMRegressionModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val data = sparkSession.read.format("parquet").load(dataPath)
+
+ val Row(intercept: Double, linear: Vector, factors: Matrix) = data
+ .select("intercept", "linear", "factors").head()
+ val model = new FMRegressionModel(metadata.uid, intercept, linear, factors)
+ metadata.getAndSetParams(model)
+ model
+ }
+ }
+}
+
+/**
+ * Factorization Machines base gradient class
+ * Implementing the raw FM formula, include raw prediction and raw gradient,
+ * then inherit the base class to implement special gradient class(like logloss, mse).
+ *
+ * Factorization Machines raw formula:
+ * {{{
+ * y_{fm} = w_0 + \sum\limits^n_{i-1} w_i x_i +
+ * \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j
+ * }}}
+ * the pairwise interactions (2-way term) can be reformulated:
+ * {{{
+ * \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j
+ * = \frac{1}{2}\sum\limits^k_{f=1}
+ * \left(\left( \sum\limits^n_{i=1}v_{i,f}x_i \right)^2 -
+ * \sum\limits^n_{i=1}v_{i,f}^2x_i^2 \right)
+ * }}}
+ * and the gradients are:
+ * {{{
+ * \frac{\partial}{\partial\theta}y_{fm} = \left\{
+ * \begin{align}
+ * &1, & if\ \theta\ is\ w_0 \\
+ * &x_i, & if\ \theta\ is\ w_i \\
+ * &x_i{\sum}^n_{j=1}v_{j,f}x_j - v_{i,f}x_i^2, & if\ \theta\ is\ v_{i,j} \\
+ * \end{align}
+ * \right.
+ * }}}
+ *
+ * Factorization Machines formula with prediction task:
+ * {{{
+ * \hat{y} = p\left( y_{fm} \right)
+ * }}}
+ * p is the prediction function, for binary classification task is sigmoid.
+ * The loss function gradient formula:
+ * {{{
+ * \frac{\partial}{\partial\theta} l\left( \hat{y},y \right) =
+ * \frac{\partial}{\partial\theta} l\left( p\left( y_{fm} \right),y \right) =
+ * \frac{\partial l}{\partial \hat{y}} \cdot
+ * \frac{\partial \hat{y}}{\partial y_{fm}} \cdot
+ * \frac{\partial y_{fm}}{\partial\theta}
+ * }}}
+ * Last term is same for all task, so be implemented in base gradient class.
+ * last term named rawGradient in following code, and first two term named multiplier.
+ */
+private[ml] abstract class BaseFactorizationMachinesGradient(
+ factorSize: Int,
+ fitIntercept: Boolean,
+ fitLinear: Boolean,
+ numFeatures: Int) extends Gradient {
+
+ override def compute(
+ data: OldVector,
+ label: Double,
+ weights: OldVector,
+ cumGradient: OldVector): Double = {
+ val (rawPrediction, sumVX) = getRawPrediction(data, weights)
+ val rawGradient = getRawGradient(data, weights, sumVX)
+ val multiplier = getMultiplier(rawPrediction, label)
+ axpy(multiplier, rawGradient, cumGradient)
+ val loss = getLoss(rawPrediction, label)
+ loss
+ }
+
+ def getPrediction(rawPrediction: Double): Double
+
+ protected def getMultiplier(rawPrediction: Double, label: Double): Double
+
+ protected def getLoss(rawPrediction: Double, label: Double): Double
+
+ def getRawPrediction(data: OldVector, weights: OldVector): (Double, Array[Double]) = {
+ val sumVX = new Array[Double](factorSize)
+ var rawPrediction = 0.0
+ val vWeightsSize = numFeatures * factorSize
+
+ if (fitIntercept) rawPrediction += weights(weights.size - 1)
+ if (fitLinear) {
+ data.foreachNonZero { case (index, value) =>
+ rawPrediction += weights(vWeightsSize + index) * value
+ }
+ }
+ (0 until factorSize).foreach { f =>
+ var sumSquare = 0.0
+ var sum = 0.0
+ data.foreachNonZero { case (index, value) =>
+ val vx = weights(index * factorSize + f) * value
+ sumSquare += vx * vx
+ sum += vx
+ }
+ sumVX(f) = sum
+ rawPrediction += 0.5 * (sum * sum - sumSquare)
+ }
+
+ (rawPrediction, sumVX)
+ }
+
+ private def getRawGradient(
+ data: OldVector,
+ weights: OldVector,
+ sumVX: Array[Double]
+ ): OldVector = {
+ data match {
+ // Usually Factorization Machines is used, there will be a lot of sparse features.
+ // So need to optimize the gradient descent of sparse vector.
+ case data: OldLinalg.SparseVector =>
+ val gardSize = data.indices.length * factorSize +
+ (if (fitLinear) data.indices.length else 0) +
+ (if (fitIntercept) 1 else 0)
+ val gradIndex = Array.ofDim[Int](gardSize)
+ val gradValue = Array.ofDim[Double](gardSize)
+ var gradI = 0
+ val vWeightsSize = numFeatures * factorSize
+
+ data.foreachNonZero { case (index, value) =>
+ (0 until factorSize).foreach { f =>
+ gradIndex(gradI) = index * factorSize + f
+ gradValue(gradI) = value * sumVX(f) - weights(index * factorSize + f) * value * value
+ gradI += 1
+ }
+ }
+ if (fitLinear) {
+ data.foreachNonZero { case (index, value) =>
+ gradIndex(gradI) = vWeightsSize + index
+ gradValue(gradI) = value
+ gradI += 1
+ }
+ }
+ if (fitIntercept) {
+ gradIndex(gradI) = weights.size - 1
+ gradValue(gradI) = 1.0
+ }
+
+ OldVectors.sparse(weights.size, gradIndex, gradValue)
+ case data: OldLinalg.DenseVector =>
+ val gradient = Array.ofDim[Double](weights.size)
+ val vWeightsSize = numFeatures * factorSize
+
+ if (fitIntercept) gradient(weights.size - 1) += 1.0
+ if (fitLinear) {
+ data.foreachNonZero { case (index, value) =>
+ gradient(vWeightsSize + index) += value
+ }
+ }
+ (0 until factorSize).foreach { f =>
+ data.foreachNonZero { case (index, value) =>
+ gradient(index * factorSize + f) +=
+ value * sumVX(f) - weights(index * factorSize + f) * value * value
+ }
+ }
+
+ OldVectors.dense(gradient)
+ }
+ }
+}
+
+/**
+ * FM with logistic loss
+ * prediction formula:
+ * {{{
+ * \hat{y} = \sigmoid(y_{fm})
+ * }}}
+ * loss formula:
+ * {{{
+ * - y * log(\hat{y}) - (1 - y) * log(1 - \hat{y})
+ * }}}
+ * multiplier formula:
+ * {{{
+ * \frac{\partial l}{\partial \hat{y}} \cdot
+ * \frac{\partial \hat{y}}{\partial y_{fm}} =
+ * \hat{y} - y
+ * }}}
+ */
+private[ml] class LogisticFactorizationMachinesGradient(
+ factorSize: Int,
+ fitIntercept: Boolean,
+ fitLinear: Boolean,
+ numFeatures: Int)
+ extends BaseFactorizationMachinesGradient(
+ factorSize: Int,
+ fitIntercept: Boolean,
+ fitLinear: Boolean,
+ numFeatures: Int) with Logging {
+
+ override def getPrediction(rawPrediction: Double): Double = {
+ 1.0 / (1.0 + math.exp(-rawPrediction))
+ }
+
+ override protected def getMultiplier(rawPrediction: Double, label: Double): Double = {
+ getPrediction(rawPrediction) - label
+ }
+
+ override protected def getLoss(rawPrediction: Double, label: Double): Double = {
+ if (label > 0) MLUtils.log1pExp(-rawPrediction)
+ else MLUtils.log1pExp(rawPrediction)
+ }
+}
+
+/**
+ * FM with mse
+ * prediction formula:
+ * {{{
+ * \hat{y} = y_{fm}
+ * }}}
+ * loss formula:
+ * {{{
+ * (\hat{y} - y) ^ 2
+ * }}}
+ * multiplier formula:
+ * {{{
+ * \frac{\partial l}{\partial \hat{y}} \cdot
+ * \frac{\partial \hat{y}}{\partial y_{fm}} =
+ * 2 * (\hat{y} - y)
+ * }}}
+ */
+private[ml] class MSEFactorizationMachinesGradient(
+ factorSize: Int,
+ fitIntercept: Boolean,
+ fitLinear: Boolean,
+ numFeatures: Int)
+ extends BaseFactorizationMachinesGradient(
+ factorSize: Int,
+ fitIntercept: Boolean,
+ fitLinear: Boolean,
+ numFeatures: Int) with Logging {
+
+ override def getPrediction(rawPrediction: Double): Double = {
+ rawPrediction
+ }
+
+ override protected def getMultiplier(rawPrediction: Double, label: Double): Double = {
+ 2 * (rawPrediction - label)
+ }
+
+ override protected def getLoss(rawPrediction: Double, label: Double): Double = {
+ (rawPrediction - label) * (rawPrediction - label)
+ }
+}
+
+/**
+ * AdamW optimizer.
+ *
+ * The implementation is based upon:
+ *
+ * Loshchilov I, Hutter F. "DECOUPLED WEIGHT DECAY REGULARIZATION" 2019.
+ *
+ * The main contribution of this paper is to improve regularization in Adam
+ * by decoupling the weight decay from the gradient-based update.
+ * This paper proposed a simple modification to recover the original formulation of
+ * weight decay regularization by decoupling the weight decay from the optimization steps
+ * taken w.r.t. the loss function.
+ */
+private[ml] class AdamWUpdater(weightSize: Int) extends Updater with Logging {
+ val beta1: Double = 0.9
+ val beta2: Double = 0.999
+ val epsilon: Double = 1e-8
+
+ val m: BV[Double] = BV.zeros[Double](weightSize).toDenseVector
+ val v: BV[Double] = BV.zeros[Double](weightSize).toDenseVector
+ var beta1T: Double = 1.0
+ var beta2T: Double = 1.0
+
+ override def compute(
+ weightsOld: OldVector,
+ gradient: OldVector,
+ stepSize: Double,
+ iter: Int,
+ regParam: Double
+ ): (OldVector, Double) = {
+ val w: BV[Double] = weightsOld.asBreeze.toDenseVector
+ val lr = stepSize // learning rate
+ if (stepSize > 0) {
+ val g: BV[Double] = gradient.asBreeze.toDenseVector
+ m *= beta1
+ brzAxpy(1 - beta1, g, m)
+ v *= beta2
+ brzAxpy(1 - beta2, g * g, v)
+ beta1T *= beta1
+ beta2T *= beta2
+ val mHat = m / (1 - beta1T)
+ val vHat = v / (1 - beta2T)
+ w -= lr * mHat / (brzSqrt(vHat) + epsilon) + regParam * w
+ }
+ val norm = brzNorm(w, 2.0)
+
+ (Vectors.fromBreeze(w), 0.5 * regParam * norm * norm)
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 15725d31a2452ae4ab76266e65eeae11324b298b..f3cd8a0812b0bdb6fbdde438262976b6279a0cf5 100644
--- a/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -17,13 +17,12 @@
package org.apache.spark.ml.regression
-import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{BLAS, Vector}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
@@ -310,7 +309,7 @@ class GBTRegressionModel private[ml](
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
- blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
+ BLAS.nativeBLAS.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
}
@Since("1.4.0")
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 67b9166a0f44db3e8a8b78461698f4bb6d5a9251..cc917db98b3288dd13a4b290ef42533604360383 100644
--- a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -383,6 +383,15 @@ private[ml] object DecisionTreeModelReadWrite {
node.impurityStats.rawCount, -1.0, -1, -1, SplitData(-1, Array.emptyDoubleArray, -1))),
id)
}
+
+ /**
+ * When save a tree model, infer the number of partitions based on number of nodes.
+ */
+ def inferNumPartitions(numNodes: Long): Int = {
+ require(numNodes > 0)
+ // 7,280,000 nodes is about 128MB
+ (numNodes / 7280000.0).ceil.toInt
+ }
}
/**
@@ -404,8 +413,8 @@ private[ml] object DecisionTreeModelReadWrite {
val dataPath = new Path(path, "data").toString
var df = sparkSession.read.parquet(dataPath)
- val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
- if (major.toInt < 3) {
+ val (major, _) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+ if (major < 3) {
df = df.withColumn("rawCount", lit(-1L))
}
@@ -459,23 +468,27 @@ private[ml] object EnsembleModelReadWrite {
def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]](
instance: M,
path: String,
- sql: SparkSession,
+ sparkSession: SparkSession,
extraMetadata: JObject): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata))
- val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map {
- case (tree, treeID) =>
- (treeID,
- DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext),
- instance.treeWeights(treeID))
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession.sparkContext, Some(extraMetadata))
+ val treesMetadataWeights = instance.trees.zipWithIndex.map { case (tree, treeID) =>
+ (treeID,
+ DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sparkSession.sparkContext),
+ instance.treeWeights(treeID))
}
val treesMetadataPath = new Path(path, "treesMetadata").toString
- sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights")
+ sparkSession.createDataFrame(treesMetadataWeights)
+ .toDF("treeID", "metadata", "weights")
+ .repartition(1)
.write.parquet(treesMetadataPath)
+
val dataPath = new Path(path, "data").toString
- val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap {
- case (tree, treeID) => EnsembleNodeData.build(tree, treeID)
- }
- sql.createDataFrame(nodeDataRDD).write.parquet(dataPath)
+ val numDataParts = NodeData.inferNumPartitions(instance.trees.map(_.numNodes.toLong).sum)
+ val nodeDataRDD = sparkSession.sparkContext.parallelize(instance.trees.zipWithIndex)
+ .flatMap { case (tree, treeID) => EnsembleNodeData.build(tree, treeID) }
+ sparkSession.createDataFrame(nodeDataRDD)
+ .repartition(numDataParts)
+ .write.parquet(dataPath)
}
/**
@@ -490,12 +503,12 @@ private[ml] object EnsembleModelReadWrite {
*/
def loadImpl(
path: String,
- sql: SparkSession,
+ sparkSession: SparkSession,
className: String,
treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
- import sql.implicits._
+ import sparkSession.implicits._
implicit val format = DefaultFormats
- val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession.sparkContext, className)
// Get impurity to construct ImpurityCalculator for each node
val impurityType: String = {
@@ -504,7 +517,7 @@ private[ml] object EnsembleModelReadWrite {
}
val treesMetadataPath = new Path(path, "treesMetadata").toString
- val treesMetadataRDD = sql.read.parquet(treesMetadataPath)
+ val treesMetadataRDD = sparkSession.read.parquet(treesMetadataPath)
.select("treeID", "metadata", "weights")
.as[(Int, String, Double)].rdd
.map { case (treeID: Int, json: String, weights: Double) =>
@@ -516,9 +529,9 @@ private[ml] object EnsembleModelReadWrite {
val treesWeights = treesMetadataWeights.map(_._2)
val dataPath = new Path(path, "data").toString
- var df = sql.read.parquet(dataPath)
- val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
- if (major.toInt < 3) {
+ var df = sparkSession.read.parquet(dataPath)
+ val (major, _) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+ if (major < 3) {
val newNodeDataCol = df.schema("nodeData").dataType match {
case StructType(fields) =>
val cols = fields.map(f => col(s"nodeData.${f.name}")) :+ lit(-1L).as("rawCount")
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index acb843a69aab2eb51f6f1887ce29c73fef96f031..f8eabaa8c8b0ad1aa09840e83d661818c648cff5 100644
--- a/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -60,8 +60,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
*/
final val maxDepth: IntParam =
new IntParam(this, "maxDepth", "Maximum depth of the tree. (Nonnegative)" +
- " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
- ParamValidators.gtEq(0))
+ " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes." +
+ " Must be in range [0, 30].",
+ ParamValidators.inRange(0, 30))
/**
* Maximum number of bins used for discretizing continuous features and for choosing how to split
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidator.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e78de431f3e0facfe72840ba3e7834284b3d3eab
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidator.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
+
+import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasParallelism
+import org.apache.spark.ml.util.{DefaultParamsReader, Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Bayesian Hyperparameter Optimization for K-fold cross validation
+ */
+class BayesianCrossValidator(override val uid: String)
+ extends Estimator[BayesianCrossValidatorModel]
+ with BayesianCrossValidatorParams with HasParallelism {
+
+ def this() = this(Identifiable.randomUID("bayesCv"))
+
+ /**
+ * TransformSchema.
+ * @param schema structType before
+ * @return structType after
+ */
+ override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
+
+ /** @group setParam */
+ def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
+
+ /** @group setParam */
+ def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
+
+ /** @group setParam */
+ def setNumFolds(value: Int): this.type = set(numFolds, value)
+
+ /** @group setParam */
+ def setNumIterations(value: Int): this.type = set(numIterations, value)
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ /** @group setParam */
+ def setEstimatorParamSpace(value: ParamSpace): this.type = set(estimatorParamSpace, value)
+
+ /** @group setParam */
+ def setParallelism(value: Int): this.type = set(parallelism, value)
+
+ /** @group setParam */
+ def setThreshold(value: Double): this.type = {
+ set(threshold, value)
+ set(thresholdFlag, true)
+ }
+
+ /** @group setParam */
+ def setThresholdFlag(value: Boolean): this.type = set(thresholdFlag, value)
+
+ private def isStop(metric: Double): Boolean = {
+ if ($(evaluator).isLargerBetter && metric >= $(threshold)) {
+ true
+ } else if (!$(evaluator).isLargerBetter && metric <= $(threshold)) {
+ true
+ } else {
+ false
+ }
+ }
+
+ var searchNumber: Int = 0
+ def getSearchNumber: Int = searchNumber
+
+ var bestMetric: Double = 0.0
+ def getBestMetric: Double = bestMetric
+
+ /**
+ * Fit the given dataset with BayesianCrossValidator.
+ *
+ * @param dataset dataset to be fitted
+ * @return BayesianCrossValidatorModel
+ */
+ override def fit(dataset: Dataset[_]): BayesianCrossValidatorModel = {
+ val sqlContext = dataset.sqlContext
+ val schema = dataset.schema
+ transformSchema(schema)
+ val est = $(estimator)
+ val eval = $(evaluator)
+ val folds = $(numFolds)
+ val iterations = $(numIterations)
+ val paramSpace = $(estimatorParamSpace)
+ val executionContext = getExecutionContext
+
+ val solver = new Solver(sqlContext.sparkSession, paramSpace, !eval.isLargerBetter,
+ BayesianCrossValidator.BATCH_SIZE, BayesianCrossValidator.SAMPLE_SIZE)
+ val splits = MLUtils.kFold(dataset.toDF.rdd, folds, $(seed)).map {
+ case (training, validation) =>
+ val trainingDataset = sqlContext.createDataFrame(training, schema).persist()
+ val validationDataset = sqlContext.createDataFrame(validation, schema).persist()
+ (trainingDataset, validationDataset)
+ }
+
+ var stop = false
+ val observations = for {
+ iter <- (1 to iterations) if !stop
+ } yield {
+ val config: ParamMap = solver.suggest()(0)
+ val accMetricsFuture = splits.map { case (training, validation) => Future[Double] {
+ val models = est.fit(training, config)
+ eval.evaluate((models.asInstanceOf[Model[_]]).transform(validation, config))
+ }(executionContext)
+ }
+ val accMetrics = accMetricsFuture.map(ThreadUtils.awaitResult(_, Duration.Inf))
+ val avgMetric: Double = (accMetrics).sum / (accMetrics).length
+ logInfo(s"Iteration $iter: $avgMetric")
+ solver.feed(config, avgMetric)
+ if ($(thresholdFlag)) {
+ stop = isStop(avgMetric)
+ }
+ (config, avgMetric)
+ }
+
+ splits.foreach {
+ case (training, validation) =>
+ training.unpersist()
+ validation.unpersist()
+ }
+
+ val bestObservation =
+ if (eval.isLargerBetter) observations.maxBy(_._2) else observations.minBy(_._2)
+ val bestParams = bestObservation._1
+ searchNumber = observations.length
+ bestMetric = bestObservation._2
+
+ logInfo(s"Best set of parameters:\n$bestParams")
+ logInfo(s"Best cross-validation metric: $bestMetric.")
+ val bestModel = (est.fit(dataset, bestParams)).asInstanceOf[Model[_]]
+ copyValues(
+ new BayesianCrossValidatorModel(uid, bestModel).setParent(this))
+ }
+
+ /**
+ * Copy a BayesianCrossValidator instance.
+ * @param extra extra ParamMap
+ * @return BayesianCrossValidator
+ */
+ override def copy(extra: ParamMap): BayesianCrossValidator = {
+ val copied = defaultCopy(extra).asInstanceOf[BayesianCrossValidator]
+ if (copied.isDefined(estimator)) {
+ copied.setEstimator(copied.getEstimator.copy(extra))
+ }
+ if (copied.isDefined(evaluator)) {
+ copied.setEvaluator(copied.getEvaluator.copy(extra))
+ }
+ copied
+ }
+}
+
+object BayesianCrossValidator {
+ private val BATCH_SIZE: Int = 1
+ private val SAMPLE_SIZE: Int = 10000
+}
+
+/**
+ * BayesianCrossValidatorModel contains the bestModel.
+ * @param uid uid of BayesianCrossValidatorModel
+ * @param bestModel optimal metrics Model
+ */
+class BayesianCrossValidatorModel private[ml](
+ override val uid: String,
+ val bestModel: Model[_])
+ extends Model[BayesianCrossValidatorModel] with BayesianCrossValidatorParams with MLWritable {
+
+ /**
+ * Write method.
+ * @return MLWriter
+ */
+ override def write: MLWriter = new BayesianCrossValidatorModel
+ .BayesianCrossValidatorModelWriter(this)
+
+ /**
+ * Copy a BayesianCrossValidatorModel instance.
+ * @param extra extra ParamMap
+ * @return BayesianCrossValidatorModel
+ */
+ override def copy(extra: ParamMap): BayesianCrossValidatorModel = {
+ val copied = new BayesianCrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
+ copyValues(copied, extra).setParent(parent)
+ }
+
+ /**
+ * Transform the dataset with best model.
+ * @param dataset dataset to be transformed
+ * @return transformed dataset
+ */
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ bestModel.transform(dataset)
+ }
+
+ /**
+ * TransformSchema
+ * @param schema structType before
+ * @return structType after
+ */
+ override def transformSchema(schema: StructType): StructType = {
+ bestModel.transformSchema(schema)
+ }
+}
+
+object BayesianCrossValidatorModel extends MLReadable[BayesianCrossValidatorModel] {
+ /**
+ * Read method
+ * @return MLReader[BayesianCrossValidatorModel]
+ */
+ override def read: MLReader[BayesianCrossValidatorModel] = new BayesianCrossValidatorModelReader
+
+ /**
+ * load method
+ * @param path file path
+ * @return BayesianCrossValidatorModel
+ */
+ override def load(path: String): BayesianCrossValidatorModel = super.load(path)
+
+ private[BayesianCrossValidatorModel]
+ class BayesianCrossValidatorModelWriter(instance: BayesianCrossValidatorModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ BayesianCrossValidatorParams.saveImpl(path, instance, sc)
+ val bestModelPath = new Path(path, "bestModel").toString
+ instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+ }
+ }
+
+ private class BayesianCrossValidatorModelReader extends MLReader[BayesianCrossValidatorModel] {
+
+ private val className = classOf[BayesianCrossValidatorModel].getName
+
+ override def load(path: String): BayesianCrossValidatorModel = {
+ implicit val format: DefaultFormats.type = DefaultFormats
+
+ val (metadata, estimator, evaluator) = BayesianCrossValidatorParams
+ .loadImpl(path, sc, className)
+ val bestModelPath = new Path(path, "bestModel").toString
+ val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
+ val model = new BayesianCrossValidatorModel(metadata.uid, bestModel)
+ model.set(model.estimator, estimator)
+ .set(model.evaluator, evaluator)
+ .set(model.numFolds, (metadata.params \ "numFolds").extract[Int])
+ .set(model.numIterations, (metadata.params \ "numIterations").extract[Int])
+ }
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidatorParams.scala b/ml-accelerator/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidatorParams.scala
new file mode 100644
index 0000000000000000000000000000000000000000..a90b7dba07a14002baf4fbbeb19b245af4f61a14
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidatorParams.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL.list2jvalue
+import org.json4s.jackson.JsonMethods.{parse, render}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.param.{BooleanParam, DoubleParam, IntParam, Param, ParamPair, Params, ParamValidators}
+import org.apache.spark.ml.param.shared.HasSeed
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, Instrumentation, MLWritable}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Params for BayesianCrossValidator.
+ */
+private[ml] trait BayesianCrossValidatorParams extends HasSeed with Params {
+
+ /**
+ * param for the estimator to be validated
+ *
+ * @group param
+ */
+ val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
+
+ /** @group getParam */
+ def getEstimator: Estimator[_] = $(estimator)
+
+ /**
+ * param for estimator param maps
+ *
+ * @group param
+ */
+ val evaluator: Param[Evaluator] = new Param(this, "evaluator",
+ "evaluator used to select hyper-parameters that maximize the validated metric")
+
+ /** @group getParam */
+ def getEvaluator: Evaluator = $(evaluator)
+
+ /**
+ * param for folds number
+ *
+ * @group param
+ */
+ val numFolds: IntParam = new IntParam(
+ this,
+ "numFolds",
+ "number of folds for cross validation (>= 2)",
+ ParamValidators.gtEq(2)
+ )
+
+ /** @group getParam */
+ def getNumFolds: Int = $(numFolds)
+
+ setDefault(numFolds -> 3)
+
+ /**
+ * param for iterations number
+ *
+ * @group param
+ */
+ val numIterations: IntParam = new IntParam(
+ this,
+ "numIterations",
+ "number of cross validations to run (>= 2)",
+ ParamValidators.gtEq(2)
+ )
+
+ /** @group getParam */
+ def getNumIterations: Int = $(numIterations)
+
+ setDefault(numIterations -> 10)
+
+ /**
+ * param for the estimator hyper-parameter space to be searched
+ *
+ * @group param
+ */
+ val estimatorParamSpace: Param[ParamSpace] =
+ new Param(this, "estimatorParamSpace", "hyper-parameter space for the estimator")
+
+ /** @group getParam */
+ def getEstimatorParamSpace: ParamSpace = $(estimatorParamSpace)
+
+ /**
+ * param for the threshold to stop search
+ *
+ * @group param
+ */
+ val threshold = new DoubleParam(this, "threshold",
+ "threshold to stop working")
+
+ /** @group getParam */
+ def getThreshold: Double = $(threshold)
+ setDefault(threshold -> 0.0)
+
+ /**
+ * param for the threshold judgment flag
+ *
+ * @group param
+ */
+ val thresholdFlag = new BooleanParam(this, "thresholdFlag",
+ "flag for threshold judgment")
+
+ /** @group getParam */
+ def getThresholdFlag: Boolean = $(thresholdFlag)
+
+ setDefault(thresholdFlag -> false)
+
+ protected def transformSchemaImpl(schema: StructType): StructType = {
+ $(estimator).transformSchema(schema)
+ }
+}
+
+private object BayesianCrossValidatorParams {
+
+ /**
+ * save method.
+ *
+ * @param path file path
+ * @param instance BayesianCrossValidatorParams
+ * @param sc SparkContext
+ * @param extraMetadata extra metadata
+ */
+ def saveImpl(
+ path: String,
+ instance: BayesianCrossValidatorParams,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None): Unit = {
+
+ val params = instance.extractParamMap().toSeq
+ val skipParams = List("estimator", "evaluator", "estimatorParamSpace")
+ val jsonParams = render(params
+ .filter { case ParamPair(p, v) => !skipParams.contains(p.name)}
+ .map { case ParamPair(p, v) =>
+ p.name -> parse(p.jsonEncode(v))
+ }.toList)
+
+ DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+ val evaluatorPath = new Path(path, "evaluator").toString
+ instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
+
+ val estimatorPath = new Path(path, "estimator").toString
+ instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
+ }
+
+ /**
+ * load method.
+ *
+ * @param path file path
+ * @param sc SparkContext
+ * @param expectedClassName expected class name
+ */
+ def loadImpl[M <: Model[M]](
+ path: String,
+ sc: SparkContext,
+ expectedClassName: String): (Metadata, Estimator[M], Evaluator) = {
+
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+ implicit val format = DefaultFormats
+ val evaluatorPath = new Path(path, "evaluator").toString
+ val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
+ val estimatorPath = new Path(path, "estimator").toString
+ val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
+
+ (metadata, estimator, evaluator)
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/ml-accelerator/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 42fd9deebe88f4d9e15e040cec2a9b09c5055d6b..13a5f5490b10e1f2044cc8664e56823355d3e561 100644
--- a/ml-accelerator/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/ml-accelerator/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -326,12 +326,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
timer.stop("docs-mapPartitions")
- val elementWiseSum = (
+ def elementWiseSum(
u: (BDM[Double], Option[BDV[Double]], Long),
- v: (BDM[Double], Option[BDV[Double]], Long)) => {
- u._1 += v._1
+ v: (BDM[Double], Option[BDV[Double]], Long)): (BDM[Double], Option[BDV[Double]], Long) = {
+ val vec =
+ if (u._1 == null) {
+ v._1
+ } else if (v._1 == null) {
+ u._1
+ } else {
+ u._1 += v._1
+ u._1
+ }
u._2.foreach(_ += v._2.get)
- (u._1, u._2, u._3 + v._3)
+ (vec, u._2, u._3 + v._3)
}
timer.start("stats-treeAggregate")
@@ -340,7 +348,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
LDAUtilsXOpt.optimizedAggregateStats(stats)
} else {
stats
- .treeAggregate((BDM.zeros[Double](k, vocabSize), logphatPartOptionBase(), 0L))(
+ .treeAggregate((null.asInstanceOf[BDM[Double]], logphatPartOptionBase(), 0L))(
elementWiseSum, elementWiseSum
)
}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
new file mode 100644
index 0000000000000000000000000000000000000000..13dd022fd1b39717c391d45eb62bea85c344445d
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.feature
+
+import breeze.linalg.{DenseVector => BDV}
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.rdd.RDD
+
+/**
+ * Inverse document frequency (IDF).
+ * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total
+ * number of documents and `d(t)` is the number of documents that contain term `t`.
+ *
+ * This implementation supports filtering out terms which do not appear in a minimum number
+ * of documents (controlled by the variable `minDocFreq`). For terms that are not in
+ * at least `minDocFreq` documents, the IDF is found as 0, resulting in TF-IDFs of 0.
+ * The document frequency is 0 as well for such terms
+ *
+ * @param minDocFreq minimum of documents in which a term
+ * should appear for filtering
+ */
+@Since("1.1.0")
+class IDF @Since("1.2.0") (@Since("1.2.0") val minDocFreq: Int) {
+
+ @Since("1.1.0")
+ def this() = this(0)
+
+ // TODO: Allow different IDF formulations.
+
+ /**
+ * Computes the inverse document frequency.
+ * @param dataset an RDD of term frequency vectors
+ */
+ @Since("1.1.0")
+ def fit(dataset: RDD[Vector]): IDFModel = {
+ val (idf, docFreq, numDocs) = IDFUtils.train(dataset, minDocFreq)
+ new IDFModel(idf, docFreq, numDocs)
+ }
+
+ /**
+ * Computes the inverse document frequency.
+ * @param dataset a JavaRDD of term frequency vectors
+ */
+ @Since("1.1.0")
+ def fit(dataset: JavaRDD[Vector]): IDFModel = {
+ fit(dataset.rdd)
+ }
+}
+
+private object IDF {
+
+ /** Document frequency aggregator. */
+ class DocumentFrequencyAggregator(val minDocFreq: Int) extends Serializable {
+
+ /** number of documents */
+ private var m = 0L
+ /** document frequency vector */
+ private var df: BDV[Long] = _
+
+
+ def this() = this(0)
+
+ /** Adds a new document. */
+ def add(doc: Vector): this.type = {
+ if (isEmpty) {
+ df = BDV.zeros(doc.size)
+ }
+ doc match {
+ case SparseVector(size, indices, values) =>
+ val nnz = indices.length
+ var k = 0
+ while (k < nnz) {
+ if (values(k) > 0) {
+ df(indices(k)) += 1L
+ }
+ k += 1
+ }
+ case DenseVector(values) =>
+ val n = values.length
+ var j = 0
+ while (j < n) {
+ if (values(j) > 0.0) {
+ df(j) += 1L
+ }
+ j += 1
+ }
+ case other =>
+ throw new UnsupportedOperationException(
+ s"Only sparse and dense vectors are supported but got ${other.getClass}.")
+ }
+ m += 1L
+ this
+ }
+
+ /** Merges another. */
+ def merge(other: DocumentFrequencyAggregator): this.type = {
+ if (!other.isEmpty) {
+ m += other.m
+ if (df == null) {
+ df = other.df.copy
+ } else {
+ df += other.df
+ }
+ }
+ this
+ }
+
+ private def isEmpty: Boolean = m == 0L
+
+ /** Returns the current IDF vector, docFreq, number of documents */
+ def idf(): (Vector, Array[Long], Long) = {
+ if (isEmpty) {
+ throw new IllegalStateException("Haven't seen any document yet.")
+ }
+ val n = df.length
+ val inv = new Array[Double](n)
+ val dfv = new Array[Long](n)
+ var j = 0
+ while (j < n) {
+ /*
+ * If the term is not present in the minimum
+ * number of documents, set IDF to 0. This
+ * will cause multiplication in IDFModel to
+ * set TF-IDF to 0.
+ *
+ * Since arrays are initialized to 0 by default,
+ * we just omit changing those entries.
+ */
+ if (df(j) >= minDocFreq) {
+ inv(j) = math.log((m + 1.0) / (df(j) + 1.0))
+ dfv(j) = df(j)
+ }
+ j += 1
+ }
+ (Vectors.dense(inv), dfv, m)
+ }
+ }
+}
+
+/**
+ * Represents an IDF model that can transform term frequency vectors.
+ */
+@Since("1.1.0")
+class IDFModel private[spark](@Since("1.1.0") val idf: Vector,
+ @Since("3.0.0") val docFreq: Array[Long],
+ @Since("3.0.0") val numDocs: Long) extends Serializable {
+
+ /**
+ * Transforms term frequency (TF) vectors to TF-IDF vectors.
+ *
+ * If `minDocFreq` was set for the IDF calculation,
+ * the terms which occur in fewer than `minDocFreq`
+ * documents will have an entry of 0.
+ *
+ * @param dataset an RDD of term frequency vectors
+ * @return an RDD of TF-IDF vectors
+ */
+ @Since("1.1.0")
+ def transform(dataset: RDD[Vector]): RDD[Vector] = {
+ val bcIdf = dataset.context.broadcast(idf)
+ dataset.mapPartitions(iter => iter.map(v => IDFModel.transform(bcIdf.value, v)))
+ }
+
+ /**
+ * Transforms a term frequency (TF) vector to a TF-IDF vector
+ *
+ * @param v a term frequency vector
+ * @return a TF-IDF vector
+ */
+ @Since("1.3.0")
+ def transform(v: Vector): Vector = IDFModel.transform(idf, v)
+
+ /**
+ * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version).
+ * @param dataset a JavaRDD of term frequency vectors
+ * @return a JavaRDD of TF-IDF vectors
+ */
+ @Since("1.1.0")
+ def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = {
+ transform(dataset.rdd).toJavaRDD()
+ }
+}
+
+private[spark] object IDFModel {
+
+ /**
+ * Transforms a term frequency (TF) vector to a TF-IDF vector with a IDF vector
+ *
+ * @param idf an IDF vector
+ * @param v a term frequency vector
+ * @return a TF-IDF vector
+ */
+ def transform(idf: Vector, v: Vector): Vector = {
+ v match {
+ case SparseVector(size, indices, values) =>
+ val (newIndices, newValues) = transformSparse(idf, indices, values)
+ Vectors.sparse(size, newIndices, newValues)
+ case DenseVector(values) =>
+ val newValues = transformDense(idf, values)
+ Vectors.dense(newValues)
+ case other =>
+ throw new UnsupportedOperationException(
+ s"Only sparse and dense vectors are supported but got ${other.getClass}.")
+ }
+ }
+
+ private[spark] def transformDense(
+ idf: Vector,
+ values: Array[Double]): Array[Double] = {
+ val n = values.length
+ val newValues = new Array[Double](n)
+ var j = 0
+ while (j < n) {
+ newValues(j) = values(j) * idf(j)
+ j += 1
+ }
+ newValues
+ }
+
+ private[spark] def transformSparse(
+ idf: Vector,
+ indices: Array[Int],
+ values: Array[Double]): (Array[Int], Array[Double]) = {
+ val nnz = indices.length
+ val newValues = new Array[Double](nnz)
+ var k = 0
+ while (k < nnz) {
+ newValues(k) = values(k) * idf(indices(k))
+ k += 1
+ }
+ (indices, newValues)
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
new file mode 100644
index 0000000000000000000000000000000000000000..255d73021bf50220d4445f5ef69c88463059cffa
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.feature
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.linalg.distributed.RowMatrix
+import org.apache.spark.mllib.stat.Statistics
+import org.apache.spark.rdd.RDD
+
+/**
+ * A feature transformer that projects vectors to a low-dimensional space using PCA.
+ *
+ * @param k number of principal components
+ */
+@Since("1.4.0")
+class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
+ require(k > 0,
+ s"Number of principal components must be positive but got ${k}")
+
+ /**
+ * Computes a [[PCAModel]] that contains the principal components of the input vectors.
+ *
+ * @param sources source vectors
+ */
+ @Since("1.4.0")
+ def fit(sources: RDD[Vector]): PCAModel = {
+ val numFeatures = sources.first().size
+ require(k <= numFeatures,
+ s"source vector size $numFeatures must be no less than k=$k")
+
+ val mat = new RowMatrix(sources)
+
+ val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
+ val densePC = pc match {
+ case dm: DenseMatrix =>
+ dm
+ case sm: SparseMatrix =>
+ /* Convert a sparse matrix to dense.
+ *
+ * RowMatrix.computePrincipalComponents always returns a dense matrix.
+ * The following code is a safeguard.
+ */
+ sm.toDense
+ case m =>
+ throw new IllegalArgumentException("Unsupported matrix format. Expected " +
+ s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}")
+ }
+ val denseExplainedVariance = explainedVariance match {
+ case dv: DenseVector =>
+ dv
+ case sv: SparseVector =>
+ sv.toDense
+ }
+ new PCAModel(k, densePC, denseExplainedVariance)
+ }
+
+ /**
+ * Java-friendly version of `fit()`.
+ */
+ @Since("1.4.0")
+ def fit(sources: JavaRDD[Vector]): PCAModel = fit(sources.rdd)
+}
+
+/**
+ * Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA.
+ *
+ * @param k number of principal components.
+ * @param pc a principal components Matrix. Each column is one principal component.
+ */
+@Since("1.4.0")
+class PCAModel private[spark] (
+ @Since("1.4.0") val k: Int,
+ @Since("1.4.0") val pc: DenseMatrix,
+ @Since("1.6.0") val explainedVariance: DenseVector) extends VectorTransformer {
+ /**
+ * Transform a vector by computed Principal Components.
+ *
+ * @param vector vector to be transformed.
+ * Vector must be the same length as the source vectors given to `PCA.fit()`.
+ * @return transformed vector. Vector will be of length k.
+ */
+ @Since("1.4.0")
+ override def transform(vector: Vector): Vector = {
+ pc.transpose.multiply(vector)
+ }
+}
+
+private[feature] object PCAUtil {
+
+ // This memory cost formula is from breeze code:
+ // https://github.com/scalanlp/breeze/blob/
+ // 6e541be066d547a097f5089165cd7c38c3ca276d/math/src/main/scala/breeze/linalg/
+ // functions/svd.scala#L87
+ def memoryCost(k: Int, numFeatures: Int): Long = {
+ 3L * math.min(k, numFeatures) * math.min(k, numFeatures) +
+ math.max(math.max(k, numFeatures), 4L * math.min(k, numFeatures) *
+ math.min(k, numFeatures) + 4L * math.min(k, numFeatures))
+ }
+
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/ml-accelerator/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
new file mode 100644
index 0000000000000000000000000000000000000000..631bc95c7e29da82250f120244ca557c5fd5e9a8
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -0,0 +1,354 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm
+
+import java.lang.{Iterable => JavaIterable}
+
+import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe._
+
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
+import org.apache.spark.{HashPartitioner, SparkContext}
+import org.apache.spark.annotation.Since
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.fpm.FPGrowth._
+import org.apache.spark.mllib.util.{Loader, Saveable}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Model trained by [[FPGrowth]], which holds frequent itemsets.
+ * @param freqItemsets frequent itemset, which is an RDD of `FreqItemset`
+ * @tparam Item item type
+ */
+@Since("1.3.0")
+class FPGrowthModel[Item: ClassTag] @Since("2.4.0")(
+ @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]],
+ @Since("2.4.0") val itemSupport: Map[Item, Double])
+ extends Saveable with Serializable {
+
+ @Since("1.3.0")
+ def this(freqItemsets: RDD[FreqItemset[Item]]) = this(freqItemsets, Map.empty)
+
+ /**
+ * Generates association rules for the `Item`s in [[freqItemsets]].
+ * @param confidence minimal confidence of the rules produced
+ */
+ @Since("1.5.0")
+ def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = {
+ val associationRules = new AssociationRules(confidence)
+ associationRules.run(freqItemsets, itemSupport)
+ }
+
+ /**
+ * Save this model to the given path.
+ * It only works for Item datatypes supported by DataFrames.
+ *
+ * This saves:
+ * - human-readable (JSON) model metadata to path/metadata/
+ * - Parquet formatted data to path/data/
+ *
+ * The model may be loaded using `FPGrowthModel.load`.
+ *
+ * @param sc Spark context used to save model data.
+ * @param path Path specifying the directory in which to save this model.
+ * If the directory already exists, this method throws an exception.
+ */
+ @Since("2.0.0")
+ override def save(sc: SparkContext, path: String): Unit = {
+ FPGrowthModel.SaveLoadV1_0.save(this, path)
+ }
+}
+
+@Since("2.0.0")
+object FPGrowthModel extends Loader[FPGrowthModel[_]] {
+
+ @Since("2.0.0")
+ override def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
+ FPGrowthModel.SaveLoadV1_0.load(sc, path)
+ }
+
+ private[fpm] object SaveLoadV1_0 {
+
+ private val thisFormatVersion = "1.0"
+
+ private val thisClassName = "org.apache.spark.mllib.fpm.FPGrowthModel"
+
+ def save(model: FPGrowthModel[_], path: String): Unit = {
+ val sc = model.freqItemsets.sparkContext
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // Get the type of item class
+ val sample = model.freqItemsets.first().items(0)
+ val className = sample.getClass.getCanonicalName
+ val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
+ val tpe = classSymbol.selfType
+
+ val itemType = ScalaReflection.schemaFor(tpe).dataType
+ val fields = Array(StructField("items", ArrayType(itemType)),
+ StructField("freq", LongType))
+ val schema = StructType(fields)
+ val rowDataRDD = model.freqItemsets.map { x =>
+ Row(x.items.toSeq, x.freq)
+ }
+ spark.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
+ implicit val formats = DefaultFormats
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+
+ val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+
+ val freqItemsets = spark.read.parquet(Loader.dataPath(path))
+ val sample = freqItemsets.select("items").head().get(0)
+ loadImpl(freqItemsets, sample)
+ }
+
+ def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = {
+ val freqItemsetsRDD = freqItemsets.select("items", "freq").rdd.map { x =>
+ val items = x.getAs[Seq[Item]](0).toArray
+ val freq = x.getLong(1)
+ new FreqItemset(items, freq)
+ }
+ new FPGrowthModel(freqItemsetsRDD)
+ }
+ }
+}
+
+/**
+ * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in
+ * Li et al., PFP: Parallel FP-Growth for Query
+ * Recommendation. PFP distributes computation in such a way that each worker executes an
+ * independent group of mining tasks. The FP-Growth algorithm is described in
+ * Han et al., Mining frequent patterns without
+ * candidate generation.
+ *
+ * @param minSupport the minimal support level of the frequent pattern, any pattern that appears
+ * more than (minSupport * size-of-the-dataset) times will be output
+ * @param numPartitions number of partitions used by parallel FP-growth
+ *
+ * @see
+ * Association rule learning (Wikipedia)
+ *
+ */
+@Since("1.3.0")
+class FPGrowth private[spark] (
+ private var minSupport: Double,
+ private var optLevel: Int,
+ private var timeLimit1: Double,
+ private var timeLimit2: Double,
+ private var numPartitions: Int) extends Logging with Serializable {
+
+ /**
+ * Constructs a default instance with default parameters {minSupport: `0.3`, optLevel: '1',
+ * timeLimit1: '0.3', timeLimit2: '0.15', numPartitions: same
+ * as the input data}.
+ *
+ */
+ @Since("1.3.0")
+ def this() = this(0.3, 1, 0.3, 0.15, -1)
+
+ /**
+ * Sets the minimal support level (default: `0.3`).
+ *
+ */
+ @Since("1.3.0")
+ def setMinSupport(minSupport: Double): this.type = {
+ require(minSupport >= 0.0 && minSupport <= 1.0,
+ s"Minimal support level must be in range [0, 1] but got ${minSupport}")
+ this.minSupport = minSupport
+ this
+ }
+
+ private def setOptLevel(optLevel: Int): this.type = {
+ require(optLevel >= 0 && optLevel <= 2,
+ s"optLevel must be [0, 1, 2] but got $optLevel")
+ this.optLevel = optLevel
+ this
+ }
+
+ private def setTimeLimit1(timeLimit: Double): this.type = {
+ require(timeLimit > 0.0 && timeLimit <= Double.MaxValue,
+ s"Time limit must be greater than 0.0 but got $timeLimit")
+ this.timeLimit1 = timeLimit
+ this
+ }
+
+ private def setTimeLimit2(timeLimit: Double): this.type = {
+ require(timeLimit > 0.0 && timeLimit <= Double.MaxValue,
+ s"Time limit must be greater than 0.0 but got $timeLimit")
+ this.timeLimit2 = timeLimit
+ this
+ }
+
+ /**
+ * Sets the number of partitions used by parallel FP-growth (default: same as input data).
+ *
+ */
+ @Since("1.3.0")
+ def setNumPartitions(numPartitions: Int): this.type = {
+ require(numPartitions > 0,
+ s"Number of partitions must be positive but got ${numPartitions}")
+ this.numPartitions = numPartitions
+ this
+ }
+
+ /**
+ * Computes an FP-Growth model that contains frequent itemsets.
+ *
+ * @param data input data set, each element contains a transaction
+ * @return an [[FPGrowthModel]]
+ *
+ */
+ @Since("1.3.0")
+ def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
+ if (data.getStorageLevel == StorageLevel.NONE) {
+ logWarning("Input data is not cached.")
+ }
+ val count = data.count()
+ val minCount = math.ceil(minSupport * count).toLong
+ val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
+ val partitioner = new HashPartitioner(numParts)
+ val freqItemsCount = FPGrowthCore.genFreqItems(data, minCount, partitioner)
+ val itemSupport = freqItemsCount.map {
+ case (item, cnt) => item -> cnt.toDouble / count
+ }.toMap
+
+ setOptLevelFromSparkConf(data.sparkContext)
+ if (optLevel == 0) {
+ val freqItemsets = FPGrowthCore.genFreqItemsets(data, minCount,
+ freqItemsCount.map(_._1), partitioner)
+ new FPGrowthModel(freqItemsets, itemSupport)
+ } else if (optLevel == 1) {
+ setLevel1TimeLimitFromSparkConf(data.sparkContext)
+ val freqItemsets = FPGrowthUtils.genFreqItemsetsByOptLevel1(data, minCount,
+ freqItemsCount.map(_._1), partitioner, timeLimit1)
+ new FPGrowthModel(freqItemsets, itemSupport)
+ } else {
+ setLevel2TimeLimitFromSparkConf(data.sparkContext)
+ val freqItemsets = FPGrowthUtils.genFreqItemsetsByOptLevel2(data, minCount,
+ freqItemsCount.map(_._1), partitioner, timeLimit1, timeLimit2)
+ new FPGrowthModel(freqItemsets, itemSupport)
+ }
+ }
+
+ private val optLevelParamKey = "spark.boostkit.ml.fpgrowth.optLevel"
+ private val timeLimit1ParamKey = "spark.boostkit.ml.fpgrowth.timeLimit1"
+ private val timeLimit2ParamKey = "spark.boostkit.ml.fpgrowth.timeLimit2"
+
+ private def setOptLevelFromSparkConf(sc: SparkContext): Unit = {
+ val optLevelStr = sc.conf.getOption(optLevelParamKey)
+ if (optLevelStr.nonEmpty) {
+ try {
+ val optLevel = optLevelStr.get.toInt
+ setOptLevel(optLevel)
+ } catch {
+ case ex: Exception =>
+ throw new IllegalArgumentException(s"Parse boostkit parameter" +
+ s"($optLevelParamKey) failed, Error reason: ${ex.getMessage}")
+ }
+ }
+ }
+
+ private def setLevel1TimeLimitFromSparkConf(sc: SparkContext): Unit = {
+ val timeLimit1Str = sc.conf.getOption(timeLimit1ParamKey)
+ if (timeLimit1Str.nonEmpty) {
+ try {
+ val timeLimit1 = timeLimit1Str.get.toDouble
+ setTimeLimit1(timeLimit1)
+ } catch {
+ case ex: Exception =>
+ throw new IllegalArgumentException(s"Parse boostkit parameter" +
+ s"($timeLimit1ParamKey) failed, Error reason: ${ex.getMessage}")
+ }
+ }
+ }
+
+ private def setLevel2TimeLimitFromSparkConf(sc: SparkContext): Unit = {
+ val timeLimit1Str = sc.conf.getOption(timeLimit1ParamKey)
+ val timeLimit2Str = sc.conf.getOption(timeLimit2ParamKey)
+ if (timeLimit1Str.nonEmpty && timeLimit2Str.nonEmpty) {
+ try {
+ val timeLimit1 = timeLimit1Str.get.toDouble
+ val timeLimit2 = timeLimit2Str.get.toDouble
+ setTimeLimit1(timeLimit1)
+ setTimeLimit2(timeLimit2)
+ } catch {
+ case ex: Exception =>
+ throw new IllegalArgumentException(s"Parse boostkit parameter" +
+ s"($timeLimit1ParamKey) failed, Error reason: ${ex.getMessage}")
+ }
+ }
+ }
+
+ /**
+ * Java-friendly version of `run`.
+ */
+ @Since("1.3.0")
+ def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
+ implicit val tag = fakeClassTag[Item]
+ run(data.rdd.map(_.asScala.toArray))
+ }
+
+}
+
+@Since("1.3.0")
+object FPGrowth {
+
+ /**
+ * Frequent itemset.
+ * @param items items in this itemset. Java users should call `FreqItemset.javaItems` instead.
+ * @param freq frequency
+ * @tparam Item item type
+ *
+ */
+ @Since("1.3.0")
+ class FreqItemset[Item] @Since("1.3.0") (
+ @Since("1.3.0") val items: Array[Item],
+ @Since("1.3.0") val freq: Long) extends Serializable {
+
+ /**
+ * Returns items in a Java List.
+ *
+ */
+ @Since("1.3.0")
+ def javaItems: java.util.List[Item] = {
+ items.toList.asJava
+ }
+
+ override def toString: String = {
+ s"${items.mkString("{", ",", "}")}: $freq"
+ }
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/ml-accelerator/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 23fabc6598db2acf0507b666c0b55c56549f6635..ecf9ca7e238326491f065dcb199dca56ffa04995 100644
--- a/ml-accelerator/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/ml-accelerator/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -209,7 +209,7 @@ object PrefixSpan extends Logging {
data.flatMap { itemsets =>
val uniqItems = mutable.Set.empty[Item]
itemsets.foreach(set => uniqItems ++= set)
- uniqItems.toIterator.map((_, 1L))
+ uniqItems.iterator.map((_, 1L))
}.reduceByKey(_ + _).filter { case (_, count) =>
count >= minCount
}.sortBy(v => (-v._2, v._1.hashCode())).map(_._1).collect()
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/ml-accelerator/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
new file mode 100644
index 0000000000000000000000000000000000000000..81579d1f11b371d7bccbcae950a8c66156d54850
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import breeze.linalg.{DenseMatrix => BDM, DenseMatrixUtil, DenseVector => BDV}
+import breeze.linalg.blas.Dgemv
+import org.netlib.util.{doubleW, intW}
+
+/**
+ * Compute eigen-decomposition.
+ */
+private[mllib] object EigenValueDecomposition {
+
+ private val DEFAULT_THREAD_NUM = 35
+
+ /**
+ * Compute the leading k eigenvalues and eigenvectors on a symmetric square matrix using ARPACK.
+ * The caller needs to ensure that the input matrix is real symmetric. This function requires
+ * memory for `n*(4*k+4)` doubles.
+ *
+ * @param mul a function that multiplies the symmetric matrix with a DenseVector.
+ * @param n dimension of the square matrix (maximum Int.MaxValue).
+ * @param k number of leading eigenvalues required, where k must be positive and less than n.
+ * @param tol tolerance of the eigs computation.
+ * @param maxIterations the maximum number of Arnoldi update iterations.
+ * @return a dense vector of eigenvalues in descending order and a dense matrix of eigenvectors
+ * (columns of the matrix).
+ * @note The number of computed eigenvalues might be smaller than k when some Ritz values do not
+ * satisfy the convergence criterion specified by tol (see ARPACK Users Guide, Chapter 4.6
+ * for more details). The maximum number of Arnoldi update iterations is set to 300 in this
+ * function.
+ */
+ def symmetricEigs(
+ mul: BDV[Double] => BDV[Double],
+ n: Int,
+ k: Int,
+ tol: Double,
+ maxIterations: Int): (BDV[Double], BDM[Double]) = {
+ // TODO: remove this function and use eigs in breeze when switching breeze version
+ require(n > k, s"Number of required eigenvalues $k must be smaller than matrix dimension $n")
+
+ // tolerance used in stopping criterion
+ val tolW = new doubleW(tol)
+ // number of desired eigenvalues, 0 < nev < n
+ val nev = new intW(k)
+ // nev Lanczos vectors are generated in the first iteration
+ // ncv-nev Lanczos vectors are generated in each subsequent iteration
+ // ncv must be smaller than n
+ val ncv = math.min(2 * k, n)
+
+ // "I" for standard eigenvalue problem, "G" for generalized eigenvalue problem
+ val bmat = "I"
+ // "LM" : compute the NEV largest (in magnitude) eigenvalues
+ val which = "LM"
+
+ val iparam = new Array[Int](11)
+ // use exact shift in each iteration
+ iparam(0) = 1
+ // maximum number of Arnoldi update iterations, or the actual number of iterations on output
+ iparam(2) = maxIterations
+ // Mode 1: A*x = lambda*x, A symmetric
+ iparam(6) = 1
+
+ require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE,
+ s"k = $k and/or n = $n are too large to compute an eigendecomposition")
+
+ val ido = new intW(0)
+ val info = new intW(0)
+ val resid = new Array[Double](n)
+ val v = new Array[Double](n * ncv)
+ val workd = new Array[Double](n * 3)
+ val workl = new Array[Double](ncv * (ncv + 8))
+ val ipntr = new Array[Int](11)
+
+ // call ARPACK's reverse communication, first iteration with ido = 0
+ ARPACK.nativeARPACK.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv,
+ v, n, iparam, ipntr, workd, workl, workl.length, info)
+
+ val w = BDV(workd)
+
+ // ido = 99 : done flag in reverse communication
+ while (ido.`val` != 99) {
+ if (ido.`val` != -1 && ido.`val` != 1) {
+ throw new IllegalStateException("ARPACK returns ido = " + ido.`val` +
+ " This flag is not compatible with Mode 1: A*x = lambda*x, A symmetric.")
+ }
+ // multiply working vector with the matrix
+ val inputOffset = ipntr(0) - 1
+ val outputOffset = ipntr(1) - 1
+ val x = w.slice(inputOffset, inputOffset + n)
+ val y = w.slice(outputOffset, outputOffset + n)
+ y := mul(x)
+ // call ARPACK's reverse communication
+ ARPACK.nativeARPACK.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv,
+ v, n, iparam, ipntr, workd, workl, workl.length, info)
+ }
+
+ if (info.`val` != 0) {
+ info.`val` match {
+ case 1 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
+ " Maximum number of iterations taken. (Refer ARPACK user guide for details)")
+ case 3 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
+ " No shifts could be applied. Try to increase NCV. " +
+ "(Refer ARPACK user guide for details)")
+ case _ => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
+ " Please refer ARPACK user guide for error message.")
+ }
+ }
+
+ val d = new Array[Double](nev.`val`)
+ val select = new Array[Boolean](ncv)
+ // copy the Ritz vectors
+ val z = java.util.Arrays.copyOfRange(v, 0, nev.`val` * n)
+
+ // call ARPACK's post-processing for eigenvectors
+ ARPACK.nativeARPACK.dseupd(true, "A", select, d, z, n, 0.0, bmat, n, which, nev, tol, resid,
+ ncv, v, n, iparam, ipntr, workd, workl, workl.length, info)
+
+ // number of computed eigenvalues, might be smaller than k
+ val computed = iparam(4)
+
+ val eigenPairs = java.util.Arrays.copyOfRange(d, 0, computed).zipWithIndex.map { r =>
+ (r._1, java.util.Arrays.copyOfRange(z, r._2 * n, r._2 * n + n))
+ }
+
+ // sort the eigen-pairs in descending order
+ val sortedEigenPairs = eigenPairs.sortBy(- _._1)
+
+ // copy eigenvectors in descending order of eigenvalues
+ val sortedU = BDM.zeros[Double](n, computed)
+ sortedEigenPairs.zipWithIndex.foreach { r =>
+ val b = r._2 * n
+ var i = 0
+ while (i < n) {
+ sortedU.data(b + i) = r._1._2(i)
+ i += 1
+ }
+ }
+
+ (BDV[Double](sortedEigenPairs.map(_._1)), sortedU)
+ }
+
+ def symmetricEigsLocal(
+ matrix: BDM[Double],
+ n: Int,
+ k: Int,
+ tol: Double,
+ maxIterations: Int,
+ driverCores: Int): (BDV[Double], BDM[Double]) = {
+ // TODO: remove this function and use eigs in breeze when switching breeze version
+ require(n > k, s"Number of required eigenvalues $k must be smaller than matrix dimension $n")
+
+ val threadNum = math.min(
+ if (driverCores < 2) DEFAULT_THREAD_NUM else driverCores, matrix.rows)
+ val blocks = DenseMatrixUtil.blockByRow(matrix, threadNum)
+
+ // tolerance used in stopping criterion
+ val tolW = new doubleW(tol)
+ // number of desired eigenvalues, 0 < nev < n
+ val nev = new intW(k)
+ // nev Lanczos vectors are generated in the first iteration
+ // ncv-nev Lanczos vectors are generated in each subsequent iteration
+ // ncv must be smaller than n
+ val ncv = math.min(2 * k, n)
+
+ // "I" for standard eigenvalue problem, "G" for generalized eigenvalue problem
+ val bmat = "I"
+ // "LM" : compute the NEV largest (in magnitude) eigenvalues
+ val which = "LM"
+
+ var iparam = new Array[Int](11)
+ // use exact shift in each iteration
+ iparam(0) = 1
+ // maximum number of Arnoldi update iterations, or the actual number of iterations on output
+ iparam(2) = maxIterations
+ // Mode 1: A*x = lambda*x, A symmetric
+ iparam(6) = 1
+
+ require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE,
+ s"k = $k and/or n = $n are too large to compute an eigendecomposition")
+
+ val ido = new intW(0)
+ val info = new intW(0)
+ val resid = new Array[Double](n)
+ val v = new Array[Double](n * ncv)
+ val workd = new Array[Double](n * 3)
+ val workl = new Array[Double](ncv * (ncv + 8))
+ val ipntr = new Array[Int](11)
+
+ // call ARPACK's reverse communication, first iteration with ido = 0
+ ARPACK.nativeARPACK.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv,
+ v, n, iparam, ipntr, workd, workl, workl.length, info)
+
+ val w = BDV(workd)
+
+ // ido = 99 : done flag in reverse communication
+ while (ido.`val` != 99) {
+ if (ido.`val` != -1 && ido.`val` != 1) {
+ throw new IllegalStateException("ARPACK returns ido = " + ido.`val` +
+ " This flag is not compatible with Mode 1: A*x = lambda*x, A symmetric.")
+ }
+
+ // multiply working vector with the matrix
+ val inputOffset = ipntr(0) - 1
+ val outputOffset = ipntr(1) - 1
+ val input = w.slice(inputOffset, inputOffset + n)
+ val output = Dgemv.compute(blocks, input)
+ System.arraycopy(output.data, 0, workd, outputOffset, n)
+
+ // call ARPACK's reverse communication
+ ARPACK.nativeARPACK.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv,
+ v, n, iparam, ipntr, workd, workl, workl.length, info)
+ }
+
+ if (info.`val` != 0) {
+ info.`val` match {
+ case 1 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
+ " Maximum number of iterations taken. (Refer ARPACK user guide for details)")
+ case 3 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
+ " No shifts could be applied. Try to increase NCV. " +
+ "(Refer ARPACK user guide for details)")
+ case _ => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
+ " Please refer ARPACK user guide for error message.")
+ }
+ }
+
+ val d = new Array[Double](nev.`val`)
+ val select = new Array[Boolean](ncv)
+ // copy the Ritz vectors
+ val z = java.util.Arrays.copyOfRange(v, 0, nev.`val` * n)
+
+ // call ARPACK's post-processing for eigenvectors
+ ARPACK.nativeARPACK.dseupd(true, "A", select, d, z, n, 0.0, bmat, n, which, nev, tol, resid,
+ ncv, v, n, iparam, ipntr, workd, workl, workl.length, info)
+
+ // number of computed eigenvalues, might be smaller than k
+ val computed = iparam(4)
+
+ val eigenPairs = java.util.Arrays.copyOfRange(d, 0, computed).zipWithIndex.map { r =>
+ (r._1, java.util.Arrays.copyOfRange(z, r._2 * n, r._2 * n + n))
+ }
+
+ // sort the eigen-pairs in descending order
+ val sortedEigenPairs = eigenPairs.sortBy(- _._1)
+
+ // copy eigenvectors in descending order of eigenvalues
+ val sortedU = BDM.zeros[Double](n, computed)
+ sortedEigenPairs.zipWithIndex.foreach { r =>
+ val b = r._2 * n
+ var i = 0
+ while (i < n) {
+ sortedU.data(b + i) = r._1._2(i)
+ i += 1
+ }
+ }
+
+ (BDV[Double](sortedEigenPairs.map(_._1)), sortedU)
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/ml-accelerator/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
new file mode 100644
index 0000000000000000000000000000000000000000..39b72ea622e7681574565123ab34fffb9b1391a5
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -0,0 +1,964 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg.distributed
+
+import java.util.Arrays
+
+import scala.collection.mutable.ListBuffer
+
+import breeze.linalg.{axpy => brzAxpy, inv, DenseMatrix => BDM, DenseVector => BDV, MatrixSingularException, SparseVector => BSV}
+import breeze.linalg.blas.Gramian
+import breeze.linalg.lapack.EigenDecomposition
+import breeze.linalg.lapack.EigenDecomposition.Eigen
+import breeze.numerics.{sqrt => brzSqrt}
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.MAX_RESULT_SIZE
+import org.apache.spark.ml.StaticUtils
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary, Statistics}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.random.XORShiftRandom
+
+
+
+/**
+ * Represents a row-oriented distributed Matrix with no meaningful row indices.
+ *
+ * @param rows rows stored as an RDD[Vector]
+ * @param nRows number of rows. A non-positive value means unknown, and then the number of rows will
+ * be determined by the number of records in the RDD `rows`.
+ * @param nCols number of columns. A non-positive value means unknown, and then the number of
+ * columns will be determined by the size of the first row.
+ */
+@Since("1.0.0")
+class RowMatrix @Since("1.0.0")(
+ @Since("1.0.0") val rows: RDD[Vector],
+ private var nRows: Long,
+ private var nCols: Int) extends DistributedMatrix with Logging {
+
+ /** Alternative constructor leaving matrix dimensions to be determined automatically. */
+ @Since("1.0.0")
+ def this(rows: RDD[Vector]) = this(rows, 0L, 0)
+
+ /** Gets or computes the number of columns. */
+ @Since("1.0.0")
+ override def numCols(): Long = {
+ if (nCols <= 0) {
+ try {
+ // Calling `first` will throw an exception if `rows` is empty.
+ nCols = rows.first().size
+ } catch {
+ case err: UnsupportedOperationException =>
+ sys.error("Cannot determine the number of cols because it is not specified in the " +
+ "constructor and the rows RDD is empty.")
+ }
+ }
+ nCols
+ }
+
+ /** Gets or computes the number of rows. */
+ @Since("1.0.0")
+ override def numRows(): Long = {
+ if (nRows <= 0L) {
+ nRows = rows.count()
+ if (nRows == 0L) {
+ sys.error("Cannot determine the number of rows because it is not specified in the " +
+ "constructor and the rows RDD is empty.")
+ }
+ }
+ nRows
+ }
+
+ /**
+ * Multiplies the Gramian matrix `A^T A` by a dense vector on the right without computing `A^T A`.
+ *
+ * @param v a dense vector whose length must match the number of columns of this matrix
+ * @return a dense vector representing the product
+ */
+ private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = {
+ val n = numCols().toInt
+ val vbr = rows.context.broadcast(v)
+ rows.treeAggregate(null.asInstanceOf[BDV[Double]])(
+ seqOp = (U, r) => {
+ val rBrz = r.asBreeze
+ val a = rBrz.dot(vbr.value)
+ val theU =
+ if (U == null) {
+ BDV.zeros[Double](n)
+ } else {
+ U
+ }
+ rBrz match {
+ // use specialized axpy for better performance
+ case _: BDV[_] => brzAxpy(a, rBrz.asInstanceOf[BDV[Double]], theU)
+ case _: BSV[_] => brzAxpy(a, rBrz.asInstanceOf[BSV[Double]], theU)
+ case _ => throw new UnsupportedOperationException(
+ s"Do not support vector operation from type ${rBrz.getClass.getName}.")
+ }
+ theU
+ }, combOp = (U1, U2) => {
+ if (U1 == null) {
+ U2
+ } else if (U2 == null) {
+ U1
+ } else {
+ U1 += U2
+ U1
+ }
+ })
+ }
+
+ /**
+ * Computes the Gramian matrix `A^T A`.
+ *
+ * @note This cannot be computed on matrices with more than 65535 columns.
+ */
+ @Since("1.0.0")
+ def computeGramianMatrix(): Matrix = {
+ if (rows.map(_.isInstanceOf[SparseVector]).reduce((x, y) => x && y)) {
+ RowMatrixUtil.computeGramMatrixAsDenseMatrix(
+ rows.map(_.asInstanceOf[SparseVector]), numCols().toInt)
+ } else {
+ computeDenseGramianMatrix()
+ }
+ }
+
+
+ /**
+ * Compute the leading k eigenvalues and eigenvectors on a symmetric square sparse matrix.
+ *
+ * @param n dimension of the square matrix (maximum Int.MaxValue).
+ * @param k number of leading eigenvalues required, where k must be positive and less than n.
+ * @param tol tolerance of the eigs computation.
+ * @param maxIter the maximum number of Arnoldi update iterations.
+ * @return a dense vector of eigenvalues in descending order and a dense matrix of eigenvectors
+ * (columns of the matrix).
+ * @note The number of computed eigenvalues might be smaller than k when some Ritz values do not
+ * satisfy the convergence criterion specified by tol (see ARPACK Users Guide, Chapter 4.6
+ * for more details). The maximum number of Arnoldi update iterations is set to 300 in this
+ * function.
+ */
+ def eigenValueDecompositionOnSparseMatrix(
+ n: Int,
+ k: Int,
+ tol: Double,
+ maxIter: Int): (BDV[Double], BDM[Double]) = {
+ val result = RowMatrixUtil.computeGramMatrix(
+ rows.map(_.asInstanceOf[SparseVector]), n)
+ EigenValueDecomposition.symmetricEigs(
+ RowMatrixUtil.multiplySparseGramMatrixBy(result),
+ n, k, tol, maxIter)
+ }
+
+ /**
+ * Compute the leading k eigenvalues and eigenvectors on a symmetric square dense matrix.
+ *
+ * @param n dimension of the square matrix (maximum Int.MaxValue).
+ * @param k number of leading eigenvalues required, where k must be positive and less than n.
+ * @param tol tolerance of the eigs computation.
+ * @param maxIter the maximum number of Arnoldi update iterations.
+ * @return a dense vector of eigenvalues in descending order and a dense matrix of eigenvectors
+ * (columns of the matrix).
+ */
+ def eigenValueDecompositionOnDenseMatrix(
+ n: Int,
+ k: Int,
+ tol: Double,
+ maxIter: Int): (BDV[Double], BDM[Double]) = {
+ val result = RowMatrixUtil.computeGramMatrix(
+ rows.map(_.asInstanceOf[SparseVector]), n)
+ val resultDenseMatrix = result._3.map{case ((i, j), sp) =>
+ ((i, j), new BDM[Double](sp.numRows, sp.numCols, sp.toArray))}
+ val newResult = (result._1, result._2, resultDenseMatrix)
+ EigenValueDecomposition.symmetricEigs(
+ RowMatrixUtil.multiplyDenseGramMatrixBy(newResult),
+ n, k, tol, maxIter)
+ }
+
+ /**
+ * Computes the Gramian matrix `A^T A` of dense matrix.
+ * @return Gramian matrix
+ */
+ def computeDenseGramianMatrix(): Matrix = {
+ val n = numCols().toInt
+ checkNumColumns(n)
+
+ // compute the upper triangular matrix
+ val gramianLen = n * (n + 1) / 2
+ val gramian = rows.mapPartitions(iter => {
+ val subMatrixValues = iter.map(_.toArray).toArray
+ val subMatrixRow = subMatrixValues.length
+ val localCovariance = new Array[Double](gramianLen)
+ Gramian.compute(subMatrixValues.flatten, localCovariance, subMatrixRow, n)
+ Array(localCovariance).iterator
+ }).treeReduce((cov1, cov2) => {
+ blas.daxpy(cov1.length, 1.0, cov2, 1, cov1, 1)
+ cov1
+ }, depth = 4)
+
+ // full fill the gramian matrix
+ val fullGramian = new Array[Double](n * n)
+ for(i <- 0 until n) {
+ val srcOffset = (2 * n - i + 1) * i / 2
+ fullGramian(i * n + i) = gramian(srcOffset)
+ for(j <- i until n) {
+ val v = gramian(srcOffset + j - i)
+ fullGramian(i * n + j) = v
+ fullGramian(j * n + i) = v
+ }
+ }
+
+ new DenseMatrix(n, n, fullGramian)
+ }
+
+
+ private def checkNumColumns(cols: Int): Unit = {
+ if (cols > 65535) {
+ throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols")
+ }
+ if (cols > 10000) {
+ val memMB = (cols.toLong * cols) / 125000
+ logWarning(s"$cols columns will require at least $memMB megabytes of memory!")
+ }
+ }
+
+ /**
+ * Computes singular value decomposition of this matrix. Denote this matrix by A (m x n). This
+ * will compute matrices U, S, V such that A ~= U * S * V', where S contains the leading k
+ * singular values, U and V contain the corresponding singular vectors.
+ *
+ * At most k largest non-zero singular values and associated vectors are returned. If there are k
+ * such values, then the dimensions of the return will be:
+ * - U is a RowMatrix of size m x k that satisfies U' * U = eye(k),
+ * - s is a Vector of size k, holding the singular values in descending order,
+ * - V is a Matrix of size n x k that satisfies V' * V = eye(k).
+ *
+ * We assume n is smaller than m, though this is not strictly required.
+ * The singular values and the right singular vectors are derived
+ * from the eigenvalues and the eigenvectors of the Gramian matrix A' * A. U, the matrix
+ * storing the right singular vectors, is computed via matrix multiplication as
+ * U = A * (V * S^-1^), if requested by user. The actual method to use is determined
+ * automatically based on the cost:
+ * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute
+ * the Gramian matrix first and then compute its top eigenvalues and eigenvectors locally
+ * on the driver. This requires a single pass with O(n^2^) storage on each executor and
+ * on the driver, and O(n^2^ k) time on the driver.
+ * - Otherwise, we compute (A' * A) * v in a distributive way and send it to ARPACK's DSAUPD to
+ * compute (A' * A)'s top eigenvalues and eigenvectors on the driver node. This requires O(k)
+ * passes, O(n) storage on each executor, and O(n k) storage on the driver.
+ *
+ * Several internal parameters are set to default values. The reciprocal condition number rCond
+ * is set to 1e-9. All singular values smaller than rCond * sigma(0) are treated as zeros, where
+ * sigma(0) is the largest singular value. The maximum number of Arnoldi update iterations for
+ * ARPACK is set to 300 or k * 3, whichever is larger. The numerical tolerance for ARPACK's
+ * eigen-decomposition is set to 1e-10.
+ *
+ * @param k number of leading singular values to keep (0 < k <= n).
+ * It might return less than k if
+ * there are numerically zero singular values or there are not enough Ritz values
+ * converged before the maximum number of Arnoldi update iterations is reached (in case
+ * that matrix A is ill-conditioned).
+ * @param computeU whether to compute U
+ * @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0)
+ * are treated as zero, where sigma(0) is the largest singular value.
+ * @return SingularValueDecomposition(U, s, V). U = null if computeU = false.
+ * @note The conditions that decide which method to use internally and the default parameters are
+ * subject to change.
+ */
+ @Since("1.0.0")
+ def computeSVD(
+ k: Int,
+ computeU: Boolean = false,
+ rCond: Double = 1e-9): SingularValueDecomposition[RowMatrix, Matrix] = {
+ // maximum number of Arnoldi update iterations for invoking ARPACK
+ val maxIter = math.max(300, k * 3)
+ // numerical tolerance for invoking ARPACK
+ val tol = 1e-10
+ computeSVD(k, computeU, rCond, maxIter, tol, "auto")
+ }
+
+ /**
+ * The actual SVD implementation, visible for testing.
+ *
+ * @param k number of leading singular values to keep (0 < k <= n)
+ * @param computeU whether to compute U
+ * @param rCond the reciprocal condition number
+ * @param maxIter max number of iterations (if ARPACK is used)
+ * @param tol termination tolerance (if ARPACK is used)
+ * @param mode computation mode (auto: determine automatically which mode to use,
+ * local-svd: compute gram matrix and computes its full SVD locally,
+ * local-eigs: compute gram matrix and computes its top eigenvalues locally,
+ * dist-eigs: compute the top eigenvalues of the gram matrix distributively)
+ * @return SingularValueDecomposition(U, s, V). U = null if computeU = false.
+ */
+ private[mllib] def computeSVD(
+ k: Int,
+ computeU: Boolean,
+ rCond: Double,
+ maxIter: Int,
+ tol: Double,
+ mode: String): SingularValueDecomposition[RowMatrix, Matrix] = {
+ val n = numCols().toInt
+ require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.")
+
+ object SVDMode extends Enumeration {
+ val LocalARPACK, LocalLAPACK, DistARPACK = Value
+ }
+
+ val modeStr = if (mode == "auto") RowMatrixUtil.selectSVDBranch(n, k) else mode
+ val computeMode = modeStr match {
+ case "local-svd" => SVDMode.LocalLAPACK
+ case "local-eigs" => SVDMode.LocalARPACK
+ case "dist-eigs" => SVDMode.DistARPACK
+ case _ => throw new IllegalArgumentException(s"Do not support mode $mode.")
+ }
+
+ val isSparse: Boolean = rows.map(_.isInstanceOf[SparseVector]).reduce((x, y) => x && y)
+
+ // Compute the eigen-decomposition of A' * A.
+ val (sigmaSquares: BDV[Double], u: BDM[Double]) = computeMode match {
+ case SVDMode.LocalARPACK =>
+ require(k < n, s"k must be smaller than n in local-eigs mode but got k=$k and n=$n.")
+ if (isSparse) {
+ eigenValueDecompositionOnDenseMatrix(n, k, tol, maxIter)
+ } else {
+ val G = computeDenseGramianMatrix().asBreeze.asInstanceOf[BDM[Double]]
+ val driverCores = RowMatrixUtil.parseExtraParams(rows.sparkContext, -1)
+ EigenValueDecomposition.symmetricEigsLocal(G, n, k, tol, maxIter, driverCores)
+ }
+ case SVDMode.LocalLAPACK =>
+ // svd latent constraint, 2 * n * n + 6 * n + 1 < Int.MaxValue
+ require(n < 32767, s"$n exceeds the breeze svd capability")
+ val G = computeGramianMatrix().asBreeze.asInstanceOf[BDM[Double]]
+ val Eigen(uFull, sigmaSquaresFull) = EigenDecomposition.symmetricEigenDecomposition(G)
+ (sigmaSquaresFull, uFull)
+ case SVDMode.DistARPACK =>
+ if (rows.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data is not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
+ require(k < n, s"k must be smaller than n in dist-eigs mode but got k=$k and n=$n.")
+ if (isSparse) {
+ eigenValueDecompositionOnSparseMatrix(n, k, tol, maxIter)
+ } else {
+ EigenValueDecomposition.symmetricEigs(multiplyGramianMatrixBy, n, k, tol, maxIter)
+ }
+ }
+
+ val sigmas: BDV[Double] = brzSqrt(sigmaSquares)
+
+ // Determine the effective rank.
+ val sigma0 = sigmas(0)
+ val threshold = rCond * sigma0
+ var i = 0
+ // sigmas might have a length smaller than k, if some Ritz values do not satisfy the convergence
+ // criterion specified by tol after max number of iterations.
+ // Thus use i < min(k, sigmas.length) instead of i < k.
+ if (sigmas.length < k) {
+ logWarning(s"Requested $k singular values but only found ${sigmas.length} converged.")
+ }
+ while (i < math.min(k, sigmas.length) && sigmas(i) >= threshold) {
+ i += 1
+ }
+ val sk = i
+
+ if (sk < k) {
+ logWarning(s"Requested $k singular values but only found $sk nonzeros.")
+ }
+
+ // Warn at the end of the run as well, for increased visibility.
+ if (computeMode == SVDMode.DistARPACK && rows.getStorageLevel == StorageLevel.NONE) {
+ logWarning("The input data was not directly cached, which may hurt performance if its"
+ + " parent RDDs are also uncached.")
+ }
+
+ val s = Vectors.dense(Arrays.copyOfRange(sigmas.data, 0, sk))
+ val V = Matrices.dense(n, sk, Arrays.copyOfRange(u.data, 0, n * sk))
+
+ if (computeU) {
+ // N = Vk * Sk^{-1}
+ val N = new BDM[Double](n, sk, Arrays.copyOfRange(u.data, 0, n * sk))
+ var i = 0
+ var j = 0
+ while (j < sk) {
+ i = 0
+ val sigma = sigmas(j)
+ while (i < n) {
+ N(i, j) /= sigma
+ i += 1
+ }
+ j += 1
+ }
+ val U = this.multiply(Matrices.fromBreeze(N))
+ SingularValueDecomposition(U, s, V)
+ } else {
+ SingularValueDecomposition(null, s, V)
+ }
+ }
+
+ /**
+ * Distributed algorithm of computing covariance matrix for a dense matrix with dimension (m,n).
+ * @param mean Mean value vector of size n
+ * @param n Column number
+ * @param m Row number
+ * @return Covariance matrix
+ */
+ private def computeDenseVectorCovariance(mean: Vector, n: Int, m: Long): Matrix = {
+ val meanBroadcast = rows.context.broadcast(mean)
+
+ // centralize matrix
+ val centralizedRows = rows.map(row => {
+ val mean = meanBroadcast.value
+ val centralizedRow = new Array[Double](n)
+ for (idx <- 0 until n)
+ centralizedRow(idx) = row(idx) - mean(idx)
+ Vectors.dense(centralizedRow)
+ })
+
+ // compute the upper triangular matrix
+ val covarianceLen = n * (n + 1) / 2
+ val covariance = centralizedRows.mapPartitions(iter => {
+ val subMatrixValues = iter.map(_.toArray).toArray
+ val subMatrixRow = subMatrixValues.length
+ val localCovariance = new Array[Double](covarianceLen)
+ Gramian.compute(subMatrixValues.flatten, localCovariance, subMatrixRow, n)
+ Array(localCovariance).iterator
+ }).treeReduce((cov1, cov2) => {
+ blas.daxpy(cov1.length, 1.0, cov2, 1, cov1, 1)
+ cov1
+ }, depth = 4)
+
+ // full fill the covariance matrix
+ val fullCovariance = new Array[Double](n * n)
+ val m1 = m - 1.0
+ for(i <- StaticUtils.ZERO_INT until n) {
+ val srcOffset = (2 * n - i + 1) * i / 2
+ fullCovariance(i * n + i) = covariance(srcOffset) / m1
+ for(j <- i + 1 until n) {
+ val v = covariance(srcOffset + j - i) / m1
+ fullCovariance(i * n + j) = v
+ fullCovariance(j * n + i) = v
+ }
+ }
+
+ new DenseMatrix(n, n, fullCovariance)
+ }
+
+ /**
+ * Distributed algorithm of computing covariance matrix for a sparse matrix with dimension (m,n).
+ * @param mean Mean value vector of size n
+ * @param n Column number
+ * @param m Row number
+ * @return Covariance matrix
+ */
+ def computeSparseVectorCovariance(mean: Vector, n: Int, m: Long): Matrix = {
+ val G = RowMatrixUtil.computeGramMatrixAsDenseMatrix(
+ rows.map(_.asInstanceOf[SparseVector]), n)
+ var i = 0
+ var j = 0
+ val m1 = m - 1.0
+ var alpha = 0.0
+ while (i < n) {
+ alpha = m / m1 * mean(i)
+ j = i
+ while (j < n) {
+ val Gij = G(i, j) / m1 - alpha * mean(j)
+ G(i, j) = Gij
+ G(j, i) = Gij
+ j += 1
+ }
+ i += 1
+ }
+ G
+ }
+
+ /**
+ * Compute covariance matrix with formula Cov(X, Y) = E[(X-E(X))(Y-E(Y))]
+ * @return Covariance matrix
+ */
+ def computeCovariance(): Matrix = {
+ val isSparse = rows.map(_.isInstanceOf[SparseVector]).reduce((x, y) => x && y)
+
+ val n = numCols().toInt
+ checkNumColumns(n)
+
+ val summary = Statistics.colStats(rows.map((_, 1.0)), Seq("count", "mean"))
+ val m = summary.count
+ require(m > 1, s"RowMatrix.computeCovariance called on matrix with only $m rows." +
+ " Cannot compute the covariance of a RowMatrix with <= 1 row.")
+ val mean = Vectors.fromML(summary.mean)
+
+ if (isSparse) {
+ computeSparseVectorCovariance(mean, n, m)
+ } else {
+ computeDenseVectorCovariance(mean, n, m)
+ }
+ }
+
+ /**
+ * Computes the top k principal components and a vector of proportions of
+ * variance explained by each principal component.
+ * Rows correspond to observations and columns correspond to variables.
+ * The principal components are stored a local matrix of size n-by-k.
+ * Each column corresponds for one principal component,
+ * and the columns are in descending order of component variance.
+ * The row data do not need to be "centered" first; it is not necessary for
+ * the mean of each column to be 0.
+ *
+ * @param k number of top principal components.
+ * @param mode number of top principal components.
+ * @return a matrix of size n-by-k, whose columns are principal components, and
+ * a vector of values which indicate how much variance each principal component
+ * explains
+ */
+ @Since("1.6.0")
+ def computePrincipalComponentsAndExplainedVarianceBody(
+ k: Int,
+ mode: String = "auto"): (Matrix, Vector) = {
+ val n = numCols().toInt
+ require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]")
+
+ object PCAMode extends Enumeration {
+ val Correlation, SVD, SparseSVD = Value
+ }
+ val checkSparseBranch = if (rows.map(_.isInstanceOf[SparseVector])
+ .reduce((x, y) => x && y)) {
+ PCAMode.SparseSVD
+ } else {
+ PCAMode.SVD
+ }
+ val computeMode = mode match {
+ case "Correlation" => PCAMode.Correlation
+ case "SVD" => checkSparseBranch
+ case _ =>
+ if (n == k || n < 1500) {
+ PCAMode.Correlation
+ } else {
+ checkSparseBranch
+ }
+ }
+ computeMode match {
+ case PCAMode.Correlation =>
+ val cov = computeCovariance().asBreeze.asInstanceOf[BDM[Double]]
+ val Eigen(u, s) = EigenDecomposition.symmetricEigenDecomposition(cov)
+
+ val eigenSum = s.data.sum
+ val explainedVariance = s.data.map(_ / eigenSum)
+ if (k == n) {
+ (Matrices.dense(n, k, u.data), Vectors.dense(explainedVariance))
+ } else {
+ (Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k)),
+ Vectors.dense(Arrays.copyOfRange(explainedVariance, 0, k)))
+ }
+ case PCAMode.SVD =>
+ val stas = Statistics.colStats(rows)
+ val meanVector = stas.mean.asBreeze
+ val centredMatrix = new RowMatrix(rows.map { rowVector =>
+ Vectors.fromBreeze(rowVector.asBreeze - meanVector)
+ })
+ val svd = centredMatrix.computeSVD(k)
+ val s = svd.s.toArray.map(eigValue => eigValue * eigValue / (numRows().toInt - 1))
+ val eigenSum = stas.variance.toArray.sum
+ val explainedVariance = s.map(_ / eigenSum)
+ (svd.V, Vectors.dense(explainedVariance))
+ case PCAMode.SparseSVD =>
+ val cov = computeCovariance().asBreeze.asInstanceOf[BDM[Double]]
+ val tol = 1e-10
+ val maxIter = math.max(3 * k, 300)
+ val newResult = RowMatrixUtil.toBlockMatrix(cov, rows.sparkContext)
+ val (sigmaSquares, u) = EigenValueDecomposition.symmetricEigs(
+ RowMatrixUtil.multiplyDenseGramMatrixBy(newResult),
+ n, k, tol, maxIter)
+ val eigenSum = Statistics.colStats(rows).variance.toArray.sum
+ val explainedVariance = sigmaSquares.toArray.map(_ / eigenSum)
+ (Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k)),
+ Vectors.dense(explainedVariance))
+ }
+ }
+
+ /**
+ * Computes the top k principal components and a vector of proportions of
+ * variance explained by each principal component.
+ *
+ * @param k number of top principal components.
+ * @return a matrix of size n-by-k, whose columns are principal components, and
+ * a vector of values which indicate how much variance each principal component
+ * explains
+ */
+ def computePrincipalComponentsAndExplainedVariance(k: Int): (Matrix, Vector) = {
+ computePrincipalComponentsAndExplainedVarianceBody(k)
+ }
+
+ /**
+ * Computes the top k principal components only.
+ *
+ * @param k number of top principal components.
+ * @return a matrix of size n-by-k, whose columns are principal components
+ * @see computePrincipalComponentsAndExplainedVariance
+ */
+ @Since("1.0.0")
+ def computePrincipalComponents(k: Int): Matrix = {
+ computePrincipalComponentsAndExplainedVariance(k)._1
+ }
+
+ /**
+ * Computes column-wise summary statistics.
+ */
+ @Since("1.0.0")
+ def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
+ val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
+ (aggregator, data) => aggregator.add(data),
+ (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
+ updateNumRows(summary.count)
+ summary
+ }
+
+ /**
+ * Multiply this matrix by a local matrix on the right.
+ *
+ * @param B a local matrix whose number of rows must match the number of columns of this matrix
+ * @return a [[org.apache.spark.mllib.linalg.distributed.RowMatrix]] representing the product,
+ * which preserves partitioning
+ */
+ @Since("1.0.0")
+ def multiply(B: Matrix): RowMatrix = {
+ val n = numCols().toInt
+ val k = B.numCols
+ require(n == B.numRows, s"Dimension mismatch: $n vs ${B.numRows}")
+
+ require(B.isInstanceOf[DenseMatrix],
+ s"Only support dense matrix at this time but found ${B.getClass.getName}.")
+
+ val Bb = rows.context.broadcast(B.asBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray)
+ val AB = rows.mapPartitions { iter =>
+ val Bi = Bb.value
+ iter.map { row =>
+ val v = BDV.zeros[Double](k)
+ var i = 0
+ while (i < k) {
+ v(i) = row.asBreeze.dot(new BDV(Bi, i * n, 1, n))
+ i += 1
+ }
+ Vectors.fromBreeze(v)
+ }
+ }
+
+ new RowMatrix(AB, nRows, B.numCols)
+ }
+
+ /**
+ * Compute all cosine similarities between columns of this matrix using the brute-force
+ * approach of computing normalized dot products.
+ *
+ * @return An n x n sparse upper-triangular matrix of cosine similarities between
+ * columns of this matrix.
+ */
+ @Since("1.2.0")
+ def columnSimilarities(): CoordinateMatrix = {
+ columnSimilarities(0.0)
+ }
+
+ /**
+ * Compute similarities between columns of this matrix using a sampling approach.
+ *
+ * The threshold parameter is a trade-off knob between estimate quality and computational cost.
+ *
+ * Setting a threshold of 0 guarantees deterministic correct results, but comes at exactly
+ * the same cost as the brute-force approach. Setting the threshold to positive values
+ * incurs strictly less computational cost than the brute-force approach, however the
+ * similarities computed will be estimates.
+ *
+ * The sampling guarantees relative-error correctness for those pairs of columns that have
+ * similarity greater than the given similarity threshold.
+ *
+ * To describe the guarantee, we set some notation:
+ * Let A be the smallest in magnitude non-zero element of this matrix.
+ * Let B be the largest in magnitude non-zero element of this matrix.
+ * Let L be the maximum number of non-zeros per row.
+ *
+ * For example, for {0,1} matrices: A=B=1.
+ * Another example, for the Netflix matrix: A=1, B=5
+ *
+ * For those column pairs that are above the threshold,
+ * the computed similarity is correct to within 20% relative error with probability
+ * at least 1 - (0.981)^10/B^
+ *
+ * The shuffle size is bounded by the *smaller* of the following two expressions:
+ *
+ * O(n log(n) L / (threshold * A))
+ * O(m L^2^)
+ *
+ * The latter is the cost of the brute-force approach, so for non-zero thresholds,
+ * the cost is always cheaper than the brute-force approach.
+ *
+ * @param threshold Set to 0 for deterministic guaranteed correctness.
+ * Similarities above this threshold are estimated
+ * with the cost vs estimate quality trade-off described above.
+ * @return An n x n sparse upper-triangular matrix of cosine similarities
+ * between columns of this matrix.
+ */
+ @Since("1.2.0")
+ def columnSimilarities(threshold: Double): CoordinateMatrix = {
+ require(threshold >= 0, s"Threshold cannot be negative: $threshold")
+
+ if (threshold > 1) {
+ logWarning(s"Threshold is greater than 1: $threshold " +
+ "Computation will be more efficient with promoted sparsity, " +
+ " however there is no correctness guarantee.")
+ }
+
+ val gamma = if (threshold < 1e-6) {
+ Double.PositiveInfinity
+ } else {
+ 10 * math.log(numCols()) / threshold
+ }
+
+ val summary = Statistics.colStats(rows.map((_, 1.0)), Seq("normL2"))
+ columnSimilaritiesDIMSUM(summary.normL2.toArray, gamma)
+ }
+
+ /**
+ * Compute QR decomposition for [[RowMatrix]]. The implementation is designed to optimize the QR
+ * decomposition (factorization) for the [[RowMatrix]] of a tall and skinny shape.
+ * Reference:
+ * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce
+ * architectures" (see here)
+ *
+ * @param computeQ whether to computeQ
+ * @return QRDecomposition(Q, R), Q = null if computeQ = false.
+ */
+ @Since("1.5.0")
+ def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = {
+ val col = numCols().toInt
+ // split rows horizontally into smaller matrices, and compute QR for each of them
+ val blockQRs = rows.retag(classOf[Vector]).glom().filter(_.length != 0).map { partRows =>
+ val bdm = BDM.zeros[Double](partRows.length, col)
+ var i = 0
+ partRows.foreach { row =>
+ bdm(i, ::) := row.asBreeze.t
+ i += 1
+ }
+ breeze.linalg.qr.reduced(bdm).r
+ }
+
+ // combine the R part from previous results vertically into a tall matrix
+ val combinedR = blockQRs.treeReduce { (r1, r2) =>
+ val stackedR = BDM.vertcat(r1, r2)
+ breeze.linalg.qr.reduced(stackedR).r
+ }
+
+ val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix)
+ val finalQ = if (computeQ) {
+ try {
+ val invR = inv(combinedR)
+ this.multiply(Matrices.fromBreeze(invR))
+ } catch {
+ case err: MatrixSingularException =>
+ logWarning("R is not invertible and return Q as null")
+ null
+ }
+ } else {
+ null
+ }
+ QRDecomposition(finalQ, finalR)
+ }
+
+ /**
+ * Find all similar columns using the DIMSUM sampling algorithm, described in two papers
+ *
+ * http://arxiv.org/abs/1206.2082
+ * http://arxiv.org/abs/1304.1467
+ *
+ * @param colMags A vector of column magnitudes
+ * @param gamma The oversampling parameter. For provable results, set to 10 * log(n) / s,
+ * where s is the smallest similarity score to be estimated,
+ * and n is the number of columns
+ * @return An n x n sparse upper-triangular matrix of cosine similarities
+ * between columns of this matrix.
+ */
+ private[mllib] def columnSimilaritiesDIMSUM(
+ colMags: Array[Double],
+ gamma: Double): CoordinateMatrix = {
+ require(gamma > 1.0, s"Oversampling should be greater than 1: $gamma")
+ require(colMags.size == this.numCols(), "Number of magnitudes didn't match column dimension")
+ val sg = math.sqrt(gamma) // sqrt(gamma) used many times
+
+ // Don't divide by zero for those columns with zero magnitude
+ val colMagsCorrected = colMags.map(x => if (x == 0) 1.0 else x)
+
+ val sc = rows.context
+ val pBV = sc.broadcast(colMagsCorrected.map(c => sg / c))
+ val qBV = sc.broadcast(colMagsCorrected.map(c => math.min(sg, c)))
+
+ val sims = rows.mapPartitionsWithIndex { (index, iter) =>
+ val p = pBV.value
+ val q = qBV.value
+
+ val rand = new XORShiftRandom(index)
+ val scaled = new Array[Double](p.size)
+ iter.flatMap { row =>
+ row match {
+ case SparseVector(size, indices, values) =>
+ val nnz = indices.size
+ var k = 0
+ while (k < nnz) {
+ scaled(k) = values(k) / q(indices(k))
+ k += 1
+ }
+
+ Iterator.tabulate (nnz) { k =>
+ val buf = new ListBuffer[((Int, Int), Double)]()
+ val i = indices(k)
+ val iVal = scaled(k)
+ if (iVal != 0 && rand.nextDouble() < p(i)) {
+ var l = k + 1
+ while (l < nnz) {
+ val j = indices(l)
+ val jVal = scaled(l)
+ if (jVal != 0 && rand.nextDouble() < p(j)) {
+ buf += (((i, j), iVal * jVal))
+ }
+ l += 1
+ }
+ }
+ buf
+ }.flatten
+ case DenseVector(values) =>
+ val n = values.size
+ var i = 0
+ while (i < n) {
+ scaled(i) = values(i) / q(i)
+ i += 1
+ }
+ Iterator.tabulate (n) { i =>
+ val buf = new ListBuffer[((Int, Int), Double)]()
+ val iVal = scaled(i)
+ if (iVal != 0 && rand.nextDouble() < p(i)) {
+ var j = i + 1
+ while (j < n) {
+ val jVal = scaled(j)
+ if (jVal != 0 && rand.nextDouble() < p(j)) {
+ buf += (((i, j), iVal * jVal))
+ }
+ j += 1
+ }
+ }
+ buf
+ }.flatten
+ case v =>
+ throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.")
+ }
+ }
+ }.reduceByKey(_ + _).map { case ((i, j), sim) =>
+ MatrixEntry(i.toLong, j.toLong, sim)
+ }
+ new CoordinateMatrix(sims, numCols(), numCols())
+ }
+
+ private[mllib] override def toBreeze(): BDM[Double] = {
+ val m = numRows().toInt
+ val n = numCols().toInt
+ val mat = BDM.zeros[Double](m, n)
+ var i = 0
+ rows.collect().foreach { vector =>
+ vector.foreachNonZero { case (j, v) =>
+ mat(i, j) = v
+ }
+ i += 1
+ }
+ mat
+ }
+
+ /** Updates or verifies the number of rows. */
+ private def updateNumRows(m: Long): Unit = {
+ if (nRows <= 0) {
+ nRows = m
+ } else {
+ require(nRows == m,
+ s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
+ }
+ }
+
+ /**
+ * Computing desired tree aggregate depth necessary to avoid exceeding
+ * driver.MaxResultSize during aggregation.
+ * Based on the formulae: (numPartitions)^(1/depth) * objectSize <= DriverMaxResultSize
+ * @param aggregatedObjectSizeInBytes the size, in megabytes, of the object being tree aggregated
+ */
+ private[spark] def getTreeAggregateIdealDepth(aggregatedObjectSizeInBytes: Long): Int = {
+ require(aggregatedObjectSizeInBytes > 0,
+ "Cannot compute aggregate depth heuristic based on a zero-size object to aggregate")
+
+ val maxDriverResultSizeInBytes = rows.conf.get[Long](MAX_RESULT_SIZE)
+ if (maxDriverResultSizeInBytes <= 0) {
+ // Unlimited result size, so 1 is OK
+ return 1
+ }
+
+ require(maxDriverResultSizeInBytes > aggregatedObjectSizeInBytes,
+ s"Cannot aggregate object of size $aggregatedObjectSizeInBytes Bytes, "
+ + s"as it's bigger than maxResultSize ($maxDriverResultSizeInBytes Bytes)")
+
+ val numerator = math.log(rows.getNumPartitions)
+ val denominator = math.log(maxDriverResultSizeInBytes) - math.log(aggregatedObjectSizeInBytes)
+ val desiredTreeDepth = math.ceil(numerator / denominator)
+
+ if (desiredTreeDepth > 4) {
+ logWarning(
+ s"Desired tree depth for treeAggregation is big ($desiredTreeDepth)."
+ + "Consider increasing driver max result size or reducing number of partitions")
+ }
+
+ math.min(math.max(1, desiredTreeDepth), 10).toInt
+ }
+}
+
+@Since("1.0.0")
+object RowMatrix {
+
+ /**
+ * Fills a full square matrix from its upper triangular part.
+ */
+ private def triuToFull(n: Int, U: Array[Double]): Matrix = {
+ val G = new BDM[Double](n, n)
+
+ var row = 0
+ var col = 0
+ var idx = 0
+ var value = 0.0
+ while (col < n) {
+ row = 0
+ while (row < col) {
+ value = U(idx)
+ G(row, col) = value
+ G(col, row) = value
+ idx += 1
+ row += 1
+ }
+ G(col, col) = U(idx)
+ idx += 1
+ col += 1
+ }
+
+ Matrices.dense(n, n, G.data)
+ }
+}
diff --git a/ml-accelerator/src/main/scala/org/apache/spark/mllib/optimization/LBFGSN.scala b/ml-accelerator/src/main/scala/org/apache/spark/mllib/optimization/LBFGSN.scala
new file mode 100644
index 0000000000000000000000000000000000000000..acb9e06aa1871bc20f18ac99a9f190a37f0d28d8
--- /dev/null
+++ b/ml-accelerator/src/main/scala/org/apache/spark/mllib/optimization/LBFGSN.scala
@@ -0,0 +1,396 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.optimization
+
+import scala.collection.mutable
+
+import breeze.linalg.{DenseVector => BDV}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGSN => BreezeLBFGSN}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS.axpy
+import org.apache.spark.rdd.RDD
+
+/**
+ * Class used to solve an optimization problem using Limited-memory BFGS.
+ * Reference:
+ * Wikipedia on Limited-memory BFGS
+ * @param gradient Gradient function to be used.
+ * @param updater Updater to be used to update weights after every iteration.
+ */
+class LBFGSN(private var gradient: Gradient, private var updater: Updater)
+ extends Optimizer with Logging {
+
+ private var numCorrections = 10
+ private var convergenceTol = 1E-6
+ private var maxNumIterations = 100
+ private var regParam = 0.0
+
+ /**
+ * Set the number of corrections used in the LBFGS update. Default 10.
+ * Values of numCorrections less than 3 are not recommended; large values
+ * of numCorrections will result in excessive computing time.
+ * numCorrections must be positive, and values from 4 to 9 are generally recommended.
+ */
+ def setNumCorrections(corrections: Int): this.type = {
+ require(corrections > 0,
+ s"Number of corrections must be positive but got ${corrections}")
+ this.numCorrections = corrections
+ this
+ }
+
+ /**
+ * Set the convergence tolerance of iterations for L-BFGS. Default 1E-6.
+ * Smaller value will lead to higher accuracy with the cost of more iterations.
+ * This value must be nonnegative. Lower convergence values are less tolerant
+ * and therefore generally cause more iterations to be run.
+ */
+ def setConvergenceTol(tolerance: Double): this.type = {
+ require(tolerance >= 0,
+ s"Convergence tolerance must be nonnegative but got ${tolerance}")
+ this.convergenceTol = tolerance
+ this
+ }
+
+ /*
+ * Get the convergence tolerance of iterations.
+ */
+ private[mllib] def getConvergenceTol(): Double = {
+ this.convergenceTol
+ }
+
+ /**
+ * Set the maximal number of iterations for L-BFGS. Default 100.
+ */
+ def setNumIterations(iters: Int): this.type = {
+ require(iters >= 0,
+ s"Maximum of iterations must be nonnegative but got ${iters}")
+ this.maxNumIterations = iters
+ this
+ }
+
+ /**
+ * Get the maximum number of iterations for L-BFGS. Defaults to 100.
+ */
+ private[mllib] def getNumIterations(): Int = {
+ this.maxNumIterations
+ }
+
+ /**
+ * Set the regularization parameter. Default 0.0.
+ */
+ def setRegParam(regParam: Double): this.type = {
+ require(regParam >= 0,
+ s"Regularization parameter must be nonnegative but got ${regParam}")
+ this.regParam = regParam
+ this
+ }
+
+ /**
+ * Get the regularization parameter.
+ */
+ private[mllib] def getRegParam(): Double = {
+ this.regParam
+ }
+
+ /**
+ * Set the gradient function (of the loss function of one single data example)
+ * to be used for L-BFGS.
+ */
+ def setGradient(gradient: Gradient): this.type = {
+ this.gradient = gradient
+ this
+ }
+
+ /**
+ * Set the updater function to actually perform a gradient step in a given direction.
+ * The updater is responsible to perform the update from the regularization term as well,
+ * and therefore determines what kind or regularization is used, if any.
+ */
+ def setUpdater(updater: Updater): this.type = {
+ this.updater = updater
+ this
+ }
+
+ /**
+ * Returns the updater, limited to internal use.
+ */
+ private[mllib] def getUpdater(): Updater = {
+ updater
+ }
+
+ override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
+ val (weights, _) = optimizeWithLossReturned(data, initialWeights)
+ weights
+ }
+
+ def optimizeWithLossReturned(
+ data: RDD[(Double, Vector)],
+ initialWeights: Vector): (Vector, Array[Double]) = {
+ LBFGSN.runLBFGS(
+ data,
+ gradient,
+ updater,
+ numCorrections,
+ convergenceTol,
+ maxNumIterations,
+ regParam,
+ initialWeights)
+ }
+}
+
+/**
+ * Top-level method to run L-BFGS.
+ */
+object LBFGSN extends Logging {
+ /**
+ * Run Limited-memory BFGS (L-BFGS) in parallel.
+ * Averaging the subgradients over different partitions is performed using one standard
+ * spark map-reduce in each iteration.
+ *
+ * @param data - Input data for L-BFGS. RDD of the set of data examples, each of
+ * the form (label, [feature values]).
+ * @param gradient - Gradient object (used to compute the gradient of the loss function of
+ * one single data example)
+ * @param updater - Updater function to actually perform a gradient step in a given direction.
+ * @param numCorrections - The number of corrections used in the L-BFGS update.
+ * @param convergenceTol - The convergence tolerance of iterations for L-BFGS which is must be
+ * nonnegative. Lower values are less tolerant and therefore generally
+ * cause more iterations to be run.
+ * @param maxNumIterations - Maximal number of iterations that L-BFGS can be run.
+ * @param regParam - Regularization parameter
+ *
+ * @return A tuple containing two elements. The first element is a column matrix containing
+ * weights for every feature, and the second element is an array containing the loss
+ * computed for every iteration.
+ */
+ def runLBFGS(
+ data: RDD[(Double, Vector)],
+ gradient: Gradient,
+ updater: Updater,
+ numCorrections: Int,
+ convergenceTol: Double,
+ maxNumIterations: Int,
+ regParam: Double,
+ initialWeights: Vector): (Vector, Array[Double]) = {
+
+ val lossHistory = mutable.ArrayBuilder.make[Double]
+
+ val numExamples = data.count()
+
+ val costFunOpt = parseSparkBooleanParam(data.context,
+ "spark.boostkit.mllib.optimization.LBFGSN.costFun.opt", true)
+ val absoluteConvergenceCheck = parseSparkBooleanParam(data.context,
+ "spark.boostkit.mllib.optimization.LBFGSN.absoluteConvergenceCheck", true)
+ val fValMemory = parseSparkIntParam(data.context,
+ "spark.boostkit.mllib.optimization.LBFGSN.fValMemory", 2)
+
+ val costFun = if (costFunOpt) {
+ new CostFunY(data, gradient, updater, regParam, numExamples)
+ } else {
+ new CostFun(data, gradient, updater, regParam, numExamples)
+ }
+
+ val lbfgsn = new BreezeLBFGSN(maxNumIterations, numCorrections, convergenceTol,
+ absoluteConvergenceCheck, fValMemory)
+
+ val states =
+ lbfgsn.iterations(new CachedDiffFunction(costFun), initialWeights.asBreeze.toDenseVector)
+
+ /**
+ * NOTE: lossSum and loss is computed using the weights from the previous iteration
+ * and regVal is the regularization value computed in the previous iteration as well.
+ */
+ var state = states.next()
+ while (states.hasNext) {
+ lossHistory += state.value
+ state = states.next()
+ }
+ lossHistory += state.value
+
+ val weights = Vectors.fromBreeze(state.x)
+
+ val lossHistoryArray = lossHistory.result()
+
+ logInfo("LBFGSN.runLBFGS finished. Last 10 losses %s".format(
+ lossHistoryArray.takeRight(10).mkString(", ")))
+
+ if (state.searchFailed && costFunOpt) {
+ logError("please consider set spark parameter " +
+ "[spark.boostkit.mllib.optimization.LBFGSN.costFun.opt] to false and try again.")
+ }
+
+ (weights, lossHistoryArray)
+ }
+
+ /**
+ * CostFun implements Breeze's DiffFunction[T], which returns the loss and gradient
+ * at a particular point (weights). It's used in Breeze's convex optimization routines.
+ */
+ private class CostFun(
+ data: RDD[(Double, Vector)],
+ gradient: Gradient,
+ updater: Updater,
+ regParam: Double,
+ numExamples: Long) extends DiffFunction[BDV[Double]] {
+
+ override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+ // Have a local copy to avoid the serialization of CostFun object which is not serializable.
+ val w = Vectors.fromBreeze(weights)
+ val n = w.size
+ val bcW = data.context.broadcast(w)
+ val localGradient = gradient
+
+ val seqOp = (c: (Vector, Double), v: (Double, Vector)) =>
+ (c, v) match {
+ case ((grad, loss), (label, features)) =>
+ val denseGrad = grad.toDense
+ val l = localGradient.compute(features, label, bcW.value, denseGrad)
+ (denseGrad, loss + l)
+ }
+
+ val combOp = (c1: (Vector, Double), c2: (Vector, Double)) =>
+ (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
+ val denseGrad1 = grad1.toDense
+ val denseGrad2 = grad2.toDense
+ axpy(1.0, denseGrad2, denseGrad1)
+ (denseGrad1, loss1 + loss2)
+ }
+
+ val zeroSparseVector = Vectors.sparse(n, Seq.empty)
+ val (gradientSum, lossSum) = data.treeAggregate((zeroSparseVector, 0.0))(seqOp, combOp)
+
+ // broadcasted model is not needed anymore
+ bcW.destroy()
+
+ /**
+ * regVal is sum of weight squares if it's L2 updater;
+ * for other updater, the same logic is followed.
+ */
+ val regVal = updater.compute(w, Vectors.zeros(n), 0, 1, regParam)._2
+
+ val loss = lossSum / numExamples + regVal
+ /**
+ * It will return the gradient part of regularization using updater.
+ *
+ * Given the input parameters, the updater basically does the following,
+ *
+ * w' = w - thisIterStepSize * (gradient + regGradient(w))
+ * Note that regGradient is function of w
+ *
+ * If we set gradient = 0, thisIterStepSize = 1, then
+ *
+ * regGradient(w) = w - w'
+ *
+ * TODO: We need to clean it up by separating the logic of regularization out
+ * from updater to regularizer.
+ */
+ // The following gradientTotal is actually the regularization part of gradient.
+ // Will add the gradientSum computed from the data with weights in the next step.
+ val gradientTotal = w.copy
+ axpy(-1.0, updater.compute(w, Vectors.zeros(n), 1, 1, regParam)._1, gradientTotal)
+
+ // gradientTotal = gradientSum / numExamples + gradientTotal
+ axpy(1.0 / numExamples, gradientSum, gradientTotal)
+
+ (loss, gradientTotal.asBreeze.asInstanceOf[BDV[Double]])
+ }
+ }
+
+ private class CostFunY(
+ data: RDD[(Double, Vector)],
+ gradient: Gradient,
+ updater: Updater,
+ regParam: Double,
+ numExamples: Long) extends DiffFunction[BDV[Double]] {
+
+ override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+ // Have a local copy to avoid the serialization of CostFun object which is not serializable.
+ val w = Vectors.fromBreeze(weights)
+ val n = w.size
+ val bcW = data.context.broadcast(w)
+
+ val (gradientSum, lossSum) = CostFunOpt.aggregate(data, n, gradient, bcW)
+
+ // broadcasted model is not needed anymore
+ bcW.destroy()
+
+ /**
+ * regVal is sum of weight squares if it's L2 updater;
+ * for other updater, the same logic is followed.
+ */
+ val regVal = updater.compute(w, Vectors.zeros(n), 0, 1, regParam)._2
+
+ val loss = lossSum / numExamples + regVal
+ /**
+ * It will return the gradient part of regularization using updater.
+ *
+ * Given the input parameters, the updater basically does the following,
+ *
+ * w' = w - thisIterStepSize * (gradient + regGradient(w))
+ * Note that regGradient is function of w
+ *
+ * If we set gradient = 0, thisIterStepSize = 1, then
+ *
+ * regGradient(w) = w - w'
+ *
+ * TODO: We need to clean it up by separating the logic of regularization out
+ * from updater to regularizer.
+ */
+ // The following gradientTotal is actually the regularization part of gradient.
+ // Will add the gradientSum computed from the data with weights in the next step.
+ val gradientTotal = w.copy
+ axpy(-1.0, updater.compute(w, Vectors.zeros(n), 1, 1, regParam)._1, gradientTotal)
+
+ // gradientTotal = gradientSum / numExamples + gradientTotal
+ axpy(1.0 / numExamples, gradientSum, gradientTotal)
+
+ (loss, gradientTotal.asBreeze.asInstanceOf[BDV[Double]])
+ }
+ }
+
+ private def parseSparkBooleanParam(sc: SparkContext, sparkParamName: String,
+ defaultValue: Boolean): Boolean = {
+ var param = defaultValue
+ try {
+ param = sc.getConf.getBoolean(sparkParamName, defaultValue)
+ } catch {
+ case e: Exception =>
+ throw new IllegalArgumentException(s"parse spark parameter" +
+ s"($sparkParamName) failed, Error reason: ${e.getMessage}")
+ }
+
+ param
+ }
+
+ private def parseSparkIntParam(sc: SparkContext, sparkParamName: String, defaultValue: Int)
+ : Int = {
+ var param = defaultValue
+ try {
+ param = sc.getConf.getInt(sparkParamName, defaultValue)
+ } catch {
+ case e: Exception =>
+ throw new IllegalArgumentException(s"parse spark parameter" +
+ s"($sparkParamName) failed, Error reason: ${e.getMessage}")
+ }
+
+ param
+ }
+}
diff --git a/ml-core/pom.xml b/ml-core/pom.xml
index 38cb635ec1d81c039c353ddf5226a148df4fad95..b87bd36dae07402afe00ffcc6735f77ae94a609c 100644
--- a/ml-core/pom.xml
+++ b/ml-core/pom.xml
@@ -2,12 +2,12 @@
org.apache.spark
boostkit-ml
- 2.2.0
+ 3.0.0
4.0.0
boostkit-ml-core_2.12
- 2.2.0
+ 3.0.0
${project.artifactId}
Spark ml core
@@ -15,9 +15,9 @@
org.apache.spark
boostkit-ml-kernel-client-core_2.12
- 2.2.0
+ 3.0.0
${spark.version}
- compile
+ provided
@@ -36,6 +36,9 @@
+ -unchecked
+ -deprecation
+ -feature
-dependencyfile
${project.build.directory}/.scala_dependencies
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/BLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/BLAS.java
deleted file mode 100644
index e6cc5fe0f6504378b383b16f3a3d13a9f35076e1..0000000000000000000000000000000000000000
--- a/ml-core/src/main/java/dev/ludovic/netlib/BLAS.java
+++ /dev/null
@@ -1,240 +0,0 @@
-/*
- * Copyright 2020, 2021, Ludovic Henry
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- *
- * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
- * information or have any questions.
- */
-
-package dev.ludovic.netlib;
-
-public interface BLAS {
-
- public static BLAS getInstance() {
- return InstanceBuilder.BLAS.getInstance();
- }
-
- public double dasum(int n, double[] x, int incx);
- public double dasum(int n, double[] x, int offsetx, int incx);
-
- public float sasum(int n, float[] x, int incx);
- public float sasum(int n, float[] x, int offsetx, int incx);
-
- public void daxpy(int n, double alpha, double[] x, int incx, double[] y, int incy);
- public void daxpy(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- public void saxpy(int n, float alpha, float[] x, int incx, float[] y, int incy);
- public void saxpy(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- public void dcopy(int n, double[] x, int incx, double[] y, int incy);
- public void dcopy(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- public void scopy(int n, float[] x, int incx, float[] y, int incy);
- public void scopy(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- public double ddot(int n, double[] x, int incx, double[] y, int incy);
- public double ddot(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- public float sdot(int n, float[] x, int incx, float[] y, int incy);
- public float sdot(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- public float sdsdot(int n, float sb, float[] sx, int incx, float[] sy, int incy);
- public float sdsdot(int n, float sb, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy);
-
- public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy);
- public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy);
- public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc);
- public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
-
- public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int Ldc);
- public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int Ldc);
-
- public void dgemv(String trans, int m, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy);
- public void dgemv(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- public void sgemv(String trans, int m, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy);
- public void sgemv(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- public void dger(int m, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda);
- public void dger(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
-
- public void sger(int m, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda);
- public void sger(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
-
- public double dnrm2(int n, double[] x, int incx);
- public double dnrm2(int n, double[] x, int offsetx, int incx);
-
- public float snrm2(int n, float[] x, int incx);
- public float snrm2(int n, float[] x, int offsetx, int incx);
-
- public void drot(int n, double[] dx, int incx, double[] dy, int incy, double c, double s);
- public void drot(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double c, double s);
-
- public void srot(int n, float[] sx, int incx, float[] sy, int incy, float c, float s);
- public void srot(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float c, float s);
-
- public void drotg(org.netlib.util.doubleW da, org.netlib.util.doubleW db, org.netlib.util.doubleW c, org.netlib.util.doubleW s);
-
- public void srotg(org.netlib.util.floatW sa, org.netlib.util.floatW sb, org.netlib.util.floatW c, org.netlib.util.floatW s);
-
- public void drotm(int n, double[] dx, int incx, double[] dy, int incy, double[] dparam);
- public void drotm(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double[] dparam, int offsetdparam);
-
- public void srotm(int n, float[] sx, int incx, float[] sy, int incy, float[] sparam);
- public void srotm(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float[] sparam, int offsetsparam);
-
- public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] dparam);
- public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] dparam, int offsetdparam);
-
- public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] sparam);
- public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] sparam, int offsetsparam);
-
- public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy);
- public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy);
- public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- public void dscal(int n, double alpha, double[] x, int incx);
- public void dscal(int n, double alpha, double[] x, int offsetx, int incx);
-
- public void sscal(int n, float alpha, float[] x, int incx);
- public void sscal(int n, float alpha, float[] x, int offsetx, int incx);
-
- public void dspmv(String uplo, int n, double alpha, double[] a, double[] x, int incx, double beta, double[] y, int incy);
- public void dspmv(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- public void sspmv(String uplo, int n, float alpha, float[] ap, float[] x, int incx, float beta, float[] y, int incy);
- public void sspmv(String uplo, int n, float alpha, float[] ap, int offsetap, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- public void dspr(String uplo, int n, double alpha, double[] x, int incx, double[] a);
- public void dspr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta);
-
- public void sspr(String uplo, int n, float alpha, float[] x, int incx, float[] ap);
- public void sspr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] ap, int offsetap);
-
- public void dspr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] ap);
- public void dspr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] ap, int offsetap);
-
- public void sspr2(String uplo, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] ap);
- public void sspr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] ap, int offsetap);
-
- public void dswap(int n, double[] x, int incx, double[] y, int incy);
- public void dswap(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- public void sswap(int n, float[] x, int incx, float[] y, int incy);
- public void sswap(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int Ldc);
- public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int Ldc);
-
- public void ssymm(String side, String uplo, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int Ldc);
- public void ssymm(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int Ldc);
-
- public void dsymv(String uplo, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy);
- public void dsymv(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- public void ssymv(String uplo, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy);
- public void ssymv(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- public void dsyr(String uplo, int n, double alpha, double[] x, int incx, double[] a, int lda);
- public void dsyr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda);
-
- public void ssyr(String uplo, int n, float alpha, float[] x, int incx, float[] a, int lda);
- public void ssyr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda);
-
- public void dsyr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda);
- public void dsyr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
-
- public void ssyr2(String uplo, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda);
- public void ssyr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
-
- public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int Ldc);
- public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int Ldc);
-
- public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int Ldc);
- public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int Ldc);
-
- public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double beta, double[] c, int Ldc);
- public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int Ldc);
-
- public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float beta, float[] c, int Ldc);
- public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int Ldc);
-
- public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx);
- public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int lda, float[] x, int incx);
- public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx);
- public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- public void stbsv(String uplo, String trans, String diag, int n, int k, float[] a, int lda, float[] x, int incx);
- public void stbsv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- public void dtpmv(String uplo, String trans, String diag, int n, double[] ap, double[] x, int incx);
- public void dtpmv(String uplo, String trans, String diag, int n, double[] ap, int offsetap, double[] x, int offsetx, int incx);
-
- public void stpmv(String uplo, String trans, String diag, int n, float[] ap, float[] x, int incx);
- public void stpmv(String uplo, String trans, String diag, int n, float[] ap, int offsetap, float[] x, int offsetx, int incx);
-
- public void dtpsv(String uplo, String trans, String diag, int n, double[] ap, double[] x, int incx);
- public void dtpsv(String uplo, String trans, String diag, int n, double[] ap, int offsetap, double[] x, int offsetx, int incx);
-
- public void stpsv(String uplo, String trans, String diag, int n, float[] ap, float[] x, int incx);
- public void stpsv(String uplo, String trans, String diag, int n, float[] ap, int offsetap, float[] x, int offsetx, int incx);
-
- public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb);
- public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
-
- public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb);
- public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
-
- public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx);
- public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- public void strmv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx);
- public void strmv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb);
- public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
-
- public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb);
- public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
-
- public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx);
- public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- public void strsv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx);
- public void strsv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- public int idamax(int n, double[] x, int incx);
- public int idamax(int n, double[] x, int offsetx, int incx);
-
- public int isamax(int n, float[] sx, int incx);
- public int isamax(int n, float[] sx, int offsetsx, int incx);
-
- public boolean lsame(String ca, String cb);
-}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/InstanceBuilder.java b/ml-core/src/main/java/dev/ludovic/netlib/InstanceBuilder.java
deleted file mode 100644
index 0d3eee6d2770f81cf4194aba30f92d73186ae7be..0000000000000000000000000000000000000000
--- a/ml-core/src/main/java/dev/ludovic/netlib/InstanceBuilder.java
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Copyright 2020, 2021, Ludovic Henry
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- *
- * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
- * information or have any questions.
- */
-
-package dev.ludovic.netlib;
-
-import java.util.logging.Logger;
-
-final class InstanceBuilder {
-
- public static final class BLAS {
- private static final dev.ludovic.netlib.BLAS instance = getInstanceImpl();
-
- public static dev.ludovic.netlib.BLAS getInstance() {
- return instance;
- }
-
- private static dev.ludovic.netlib.BLAS getInstanceImpl() {
- try {
- return dev.ludovic.netlib.NativeBLAS.getInstance();
- } catch (Throwable t) {
- Logger.getLogger(BLAS.class.getName()).warning("Failed to load implementation from:" + dev.ludovic.netlib.NativeBLAS.class.getName());
- }
- return dev.ludovic.netlib.JavaBLAS.getInstance();
- }
- }
-
- public static final class NativeBLAS {
- private static final dev.ludovic.netlib.NativeBLAS instance = getInstanceImpl();
-
- public static dev.ludovic.netlib.NativeBLAS getInstance() {
- return instance;
- }
-
- private static dev.ludovic.netlib.NativeBLAS getInstanceImpl() {
- try {
- return dev.ludovic.netlib.blas.JNIBLAS.getInstance();
- } catch (Throwable t) {
- Logger.getLogger(NativeBLAS.class.getName()).warning("Failed to load implementation from:" + dev.ludovic.netlib.blas.JNIBLAS.class.getName());
- }
- throw new RuntimeException("Unable to load native implementation");
- }
- }
-
- public static final class JavaBLAS {
- private static final dev.ludovic.netlib.JavaBLAS instance = getInstanceImpl();
-
- public static dev.ludovic.netlib.JavaBLAS getInstance() {
- return instance;
- }
-
- private static dev.ludovic.netlib.JavaBLAS getInstanceImpl() {
- return dev.ludovic.netlib.blas.Java8BLAS.getInstance();
- }
- }
-}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/JavaBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/JavaBLAS.java
deleted file mode 100644
index 834aa3986348ad8f88690ac0f0948c8c2b439452..0000000000000000000000000000000000000000
--- a/ml-core/src/main/java/dev/ludovic/netlib/JavaBLAS.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Copyright 2020, 2021, Ludovic Henry
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- *
- * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
- * information or have any questions.
- */
-
-package dev.ludovic.netlib;
-
-public interface JavaBLAS extends BLAS {
-
- public static JavaBLAS getInstance() {
- return InstanceBuilder.JavaBLAS.getInstance();
- }
-}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/NativeBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/NativeBLAS.java
deleted file mode 100644
index a6fd83a4b2980a83b6e56cf08de2cfd27c165cfd..0000000000000000000000000000000000000000
--- a/ml-core/src/main/java/dev/ludovic/netlib/NativeBLAS.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Copyright 2020, 2021, Ludovic Henry
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- *
- * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
- * information or have any questions.
- */
-
-package dev.ludovic.netlib;
-
-public interface NativeBLAS extends BLAS {
-
- public static NativeBLAS getInstance() {
- return InstanceBuilder.NativeBLAS.getInstance();
- }
-}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java
deleted file mode 100644
index 00f3c64c5e8135c7ce7f60c67bba5785c2f5467a..0000000000000000000000000000000000000000
--- a/ml-core/src/main/java/dev/ludovic/netlib/blas/AbstractBLAS.java
+++ /dev/null
@@ -1,1689 +0,0 @@
-/*
- * Copyright 2020, 2021, Ludovic Henry
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- *
- * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
- * information or have any questions.
- */
-
-package dev.ludovic.netlib.blas;
-
-import java.util.Objects;
-
-import dev.ludovic.netlib.BLAS;
-
-abstract class AbstractBLAS implements BLAS {
-
- private final static boolean debug = System.getProperty("dev.ludovic.netlib.blas.debug", "false").equals("true");
-
- protected int loopAlign(int index, int max, int size) {
- return Math.min(loopBound(index + size - 1, size), max);
- }
-
- protected int loopBound(int index, int size) {
- return index - (index % size);
- }
-
- private void checkArgument(String method, int arg, boolean check) {
- if (!check) {
- throw new IllegalArgumentException(String.format("** On entry to '%s' parameter number %d had an illegal value", method, arg));
- }
- }
-
- private void checkIndex(int index, int length) {
- //FIXME: switch to Objects.checkIndex when the minimum version becomes JDK 11
- if (index < 0 || index >= length) {
- throw new IndexOutOfBoundsException(String.format("Index %s out of bounds for length %s", index, length));
- }
- }
-
- private void requireNonNull(T obj) {
- Objects.requireNonNull(obj);
- }
-
- public double dasum(int n, double[] x, int incx) {
- if (debug) System.err.println("dasum");
- return dasum(n, x, 0, incx);
- }
-
- public double dasum(int n, double[] x, int offsetx, int incx) {
- if (debug) System.err.println("dasum");
- if (n <= 0) {
- return 0.0;
- }
- requireNonNull(x);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- return dasumK(n, x, offsetx, incx);
- }
-
- protected abstract double dasumK(int n, double[] x, int offsetx, int incx);
-
- public float sasum(int n, float[] x, int incx) {
- if (debug) System.err.println("sasum");
- return sasum(n, x, 0, incx);
- }
-
- public float sasum(int n, float[] x, int offsetx, int incx) {
- if (debug) System.err.println("sasum");
- if (n <= 0) {
- return 0.0f;
- }
- requireNonNull(x);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- return sasumK(n, x, offsetx, incx);
- }
-
- protected abstract float sasumK(int n, float[] x, int offsetx, int incx);
-
- public void daxpy(int n, double alpha, double[] x, int incx, double[] y, int incy) {
- if (debug) System.err.println("daxpy");
- daxpy(n, alpha, x, 0, incx, y, 0, incy);
- }
-
- // y += alpha * x
- public void daxpy(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- if (debug) System.err.println("daxpy");
- if (n <= 0) {
- return;
- }
- if (alpha == 0.0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- daxpyK(n, alpha, x, offsetx, incx, y, offsety, incy);
- }
-
- protected abstract void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- public void saxpy(int n, float alpha, float[] x, int incx, float[] y, int incy) {
- if (debug) System.err.println("saxpy");
- saxpy(n, alpha, x, 0, incx, y, 0, incy);
- }
-
- // y += alpha * x
- public void saxpy(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- if (debug) System.err.println("saxpy");
- if (n <= 0) {
- return;
- }
- if (alpha == 0.0f) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- saxpyK(n, alpha, x, offsetx, incx, y, offsety, incy);
- }
-
- protected abstract void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- public void dcopy(int n, double[] x, int incx, double[] y, int incy) {
- if (debug) System.err.println("dcopy");
- dcopy(n, x, 0, incx, y, 0, incy);
- }
-
- public void dcopy(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- if (debug) System.err.println("dcopy");
- if (n <= 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- dcopyK(n, x, offsetx, incx, y, offsety, incy);
- }
-
- protected abstract void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- public void scopy(int n, float[] x, int incx, float[] y, int incy) {
- if (debug) System.err.println("scopy");
- scopy(n, x, 0, incx, y, 0, incy);
- }
-
- public void scopy(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- if (debug) System.err.println("scopy");
- if (n <= 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- scopyK(n, x, offsetx, incx, y, offsety, incy);
- }
-
- protected abstract void scopyK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- public double ddot(int n, double[] x, int incx, double[] y, int incy) {
- if (debug) System.err.println("ddot");
- return ddot(n, x, 0, incx, y, 0, incy);
- }
-
- // sum(x * y)
- public double ddot(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- if (debug) System.err.println("ddot");
- if (n <= 0) {
- return 0.0;
- }
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- return ddotK(n, x, offsetx, incx, y, offsety, incy);
- }
-
- protected abstract double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- public float sdot(int n, float[] x, int incx, float[] y, int incy) {
- if (debug) System.err.println("sdot");
- return sdot(n, x, 0, incx, y, 0, incy);
- }
-
- // sum(x * y)
- public float sdot(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- if (debug) System.err.println("sdot");
- if (n <= 0) {
- return 0.0f;
- }
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- return sdotK(n, x, offsetx, incx, y, offsety, incy);
- }
-
- protected abstract float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- public float sdsdot(int n, float sb, float[] x, int incx, float[] y, int incy) {
- if (debug) System.err.println("sdsdot");
- return sdsdot(n, sb, x, 0, incx, y, 0, incy);
- }
-
- public float sdsdot(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- if (debug) System.err.println("sdsdot");
- if (n <= 0) {
- return 0.0f;
- }
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- return sdsdotK(n, sb, x, offsetx, incx, y, offsety, incy);
- }
-
- protected abstract float sdsdotK(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) {
- if (debug) System.err.println("dgbmv");
- dgbmv(trans, m, n, kl, ku, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
- }
-
- public void dgbmv(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- if (debug) System.err.println("dgbmv");
- requireNonNull(a);
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length);
- dgbmvK(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected abstract void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) {
- if (debug) System.err.println("sgbmv");
- sgbmv(trans, m, n, kl, ku, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
- }
-
- public void sgbmv(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- if (debug) System.err.println("sgbmv");
- requireNonNull(a);
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length);
- sgbmvK(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected abstract void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc) {
- if (debug) System.err.println("dgemm");
- dgemm(transa, transb, m, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
- }
-
- // c = alpha * a * b + beta * c
- public void dgemm(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- if (debug) System.err.println("dgemm");
- checkArgument("DGEMM", 1, lsame("T", transa) || lsame("N", transa) || lsame("C", transa));
- checkArgument("DGEMM", 2, lsame("T", transb) || lsame("N", transb) || lsame("C", transb));
- checkArgument("DGEMM", 3, m >= 0);
- checkArgument("DGEMM", 4, n >= 0);
- checkArgument("DGEMM", 5, k >= 0);
- checkArgument("DGEMM", 8, lda >= Math.max(1, lsame("N", transa) ? m : k));
- checkArgument("DGEMM", 10, ldb >= Math.max(1, lsame("N", transb) ? k : n));
- checkArgument("DGEMM", 13, ldc >= Math.max(1, m));
- if (m == 0 || n == 0 || ((alpha == 0.0 || k == 0) && beta == 1.0)) {
- return;
- }
- requireNonNull(a);
- requireNonNull(b);
- requireNonNull(c);
- checkIndex(offseta + (lsame("N", transa) ? k : m) * lda - 1, a.length);
- checkIndex(offsetb + (lsame("N", transb) ? n : k) * ldb - 1, b.length);
- checkIndex(offsetc + m * n - 1, c.length);
- dgemmK(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected abstract void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
-
- public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int ldc) {
- if (debug) System.err.println("sgemm");
- sgemm(transa, transb, m, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
- }
-
- // c = alpha * a * b + beta * c
- public void sgemm(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- if (debug) System.err.println("sgemm");
- checkArgument("SGEMM", 1, lsame("T", transa) || lsame("N", transa) || lsame("C", transa));
- checkArgument("SGEMM", 2, lsame("T", transb) || lsame("N", transb) || lsame("C", transb));
- checkArgument("SGEMM", 3, m >= 0);
- checkArgument("SGEMM", 4, n >= 0);
- checkArgument("SGEMM", 5, k >= 0);
- checkArgument("SGEMM", 8, lda >= Math.max(1, lsame("N", transa) ? m : k));
- checkArgument("SGEMM", 10, ldb >= Math.max(1, lsame("N", transb) ? k : n));
- checkArgument("SGEMM", 13, ldc >= Math.max(1, m));
- if (m == 0 || n == 0 || ((alpha == 0.0f || k == 0) && beta == 1.0f)) {
- return;
- }
- requireNonNull(a);
- requireNonNull(b);
- requireNonNull(c);
- checkIndex(offseta + (lsame("N", transa) ? k : m) * lda - 1, a.length);
- checkIndex(offsetb + (lsame("N", transb) ? n : k) * ldb - 1, b.length);
- checkIndex(offsetc + m * n - 1, c.length);
- sgemmK(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected abstract void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
-
- public void dgemv(String trans, int m, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) {
- if (debug) System.err.println("dgemv");
- dgemv(trans, m, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
- }
-
- // y = alpha * A * x + beta * y
- public void dgemv(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- if (debug) System.err.println("dgemv");
- checkArgument("DGEMV", 1, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("DGEMV", 2, m >= 0);
- checkArgument("DGEMV", 3, n >= 0);
- checkArgument("DGEMV", 6, lda >= Math.max(1, m));
- checkArgument("DGEMV", 8, incx != 0);
- checkArgument("DGEMV", 11, incy != 0);
- if (m == 0 || n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length);
- dgemvK(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected abstract void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- public void sgemv(String trans, int m, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) {
- if (debug) System.err.println("sgemv");
- sgemv(trans, m, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
- }
-
- // y = alpha * A * x + beta * y
- public void sgemv(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- if (debug) System.err.println("sgemv");
- checkArgument("SGEMV", 1, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("SGEMV", 2, m >= 0);
- checkArgument("SGEMV", 3, n >= 0);
- checkArgument("SGEMV", 6, lda >= Math.max(1, m));
- checkArgument("SGEMV", 8, incx != 0);
- checkArgument("SGEMV", 11, incy != 0);
- if (m == 0 || n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + ((lsame("N", trans) ? n : m) - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + ((lsame("N", trans) ? m : n) - 1) * Math.abs(incy), y.length);
- sgemvK(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected abstract void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- // A += alpha * x * y.t
- public void dger(int m, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda) {
- if (debug) System.err.println("dger");
- dger(m, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda);
- }
-
- public void dger(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
- if (debug) System.err.println("dger");
- checkArgument("DGER", 1, m >= 0);
- checkArgument("DGER", 2, n >= 0);
- checkArgument("DGER", 5, incx != 0);
- checkArgument("DGER", 7, incy != 0);
- checkArgument("DGER", 9, lda >= Math.max(1, m));
- if (m == 0 || n == 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- requireNonNull(a);
- checkIndex(offsetx + (m - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- checkIndex(offseta + n * lda - 1, a.length);
- if (alpha != 0.0) {
- dgerK(m, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
- }
- }
-
- protected abstract void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
-
- public void sger(int m, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda) {
- if (debug) System.err.println("sger");
- sger(m, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda);
- }
-
- public void sger(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
- if (debug) System.err.println("sger");
- checkArgument("SGER", 1, m >= 0);
- checkArgument("SGER", 2, n >= 0);
- checkArgument("SGER", 5, incx != 0);
- checkArgument("SGER", 7, incy != 0);
- checkArgument("SGER", 9, lda >= Math.max(1, m));
- if (m == 0 || n == 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- requireNonNull(a);
- checkIndex(offsetx + (m - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- checkIndex(offseta + n * lda - 1, a.length);
- if (alpha != 0.0f) {
- sgerK(m, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
- }
- }
-
- protected abstract void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
-
- public double dnrm2(int n, double[] x, int incx) {
- if (debug) System.err.println("dnrm2");
- return dnrm2(n, x, 0, incx);
- }
-
- public double dnrm2(int n, double[] x, int offsetx, int incx) {
- if (debug) System.err.println("dnrm2");
- if (n <= 0) {
- return 0.0;
- }
- if (incx <= 0) {
- return 0.0;
- }
- if (n == 1) {
- return Math.abs(x[offsetx + 0]);
- }
- requireNonNull(x);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- return dnrm2K(n, x, offsetx, incx);
- }
-
- protected abstract double dnrm2K(int n, double[] x, int offsetx, int incx);
-
- public float snrm2(int n, float[] x, int incx) {
- if (debug) System.err.println("snrm2");
- return snrm2(n, x, 0, incx);
- }
-
- public float snrm2(int n, float[] x, int offsetx, int incx) {
- if (debug) System.err.println("snrm2");
- if (n <= 0) {
- return 0.0f;
- }
- if (incx <= 0) {
- return 0.0f;
- }
- if (n == 1) {
- return Math.abs(x[offsetx + 0]);
- }
- requireNonNull(x);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- return snrm2K(n, x, offsetx, incx);
- }
-
- protected abstract float snrm2K(int n, float[] x, int offsetx, int incx);
-
- public void drot(int n, double[] x, int incx, double[] y, int incy, double c, double s) {
- if (debug) System.err.println("drot");
- drot(n, x, 0, incx, y, 0, incy, c, s);
- }
-
- public void drot(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s) {
- if (debug) System.err.println("drot");
- if (n <= 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- drotK(n, x, offsetx, incx, y, offsety, incy, c, s);
- }
-
- protected abstract void drotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s);
-
- public void srot(int n, float[] x, int incx, float[] y, int incy, float c, float s) {
- if (debug) System.err.println("srot");
- srot(n, x, 0, incx, y, 0, incy, c, s);
- }
-
- public void srot(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s) {
- if (debug) System.err.println("srot");
- if (n <= 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- srotK(n, x, offsetx, incx, y, offsety, incy, c, s);
- }
-
- protected abstract void srotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s);
-
- public void drotg(org.netlib.util.doubleW da, org.netlib.util.doubleW db, org.netlib.util.doubleW c, org.netlib.util.doubleW s) {
- if (debug) System.err.println("drotg");
- double scale = Math.abs(da.val) + Math.abs(db.val);
- if (scale == 0.0) {
- c.val = 1.0;
- s.val = 0.0;
- da.val = 0.0;
- db.val = 0.0;
- } else {
- double r = scale * Math.sqrt(Math.pow(da.val / scale, 2) + Math.pow(db.val / scale, 2))
- * ((Math.abs(da.val) > Math.abs(db.val) ? da.val : db.val) >= 0.0 ? 1.0 : -1.0);
- c.val = da.val / r;
- s.val = db.val / r;
- double z = 1.0;
- if (Math.abs(da.val) > Math.abs(db.val)) {
- z = s.val;
- } else if (c.val != 0.0) {
- z = 1.0 / c.val;
- }
- da.val = r;
- db.val = z;
- }
- }
-
- public void srotg(org.netlib.util.floatW sa, org.netlib.util.floatW sb, org.netlib.util.floatW c, org.netlib.util.floatW s) {
- if (debug) System.err.println("srotg");
- float scale = Math.abs(sa.val) + Math.abs(sb.val);
- if (scale == 0.0f) {
- c.val = 1.0f;
- s.val = 0.0f;
- sa.val = 0.0f;
- sb.val = 0.0f;
- } else {
- float r = (float)(scale * Math.sqrt(Math.pow(sa.val / scale, 2) + Math.pow(sb.val / scale, 2))
- * ((Math.abs(sa.val) > Math.abs(sb.val) ? sa.val : sb.val) >= 0.0f ? 1.0 : -1.0));
- c.val = sa.val / r;
- s.val = sb.val / r;
- float z = 1.0f;
- if (Math.abs(sa.val) > Math.abs(sb.val)) {
- z = s.val;
- } else if (c.val != 0.0f) {
- z = 1.0f / c.val;
- }
- sa.val = r;
- sb.val = z;
- }
- }
-
- public void drotm(int n, double[] x, int incx, double[] y, int incy, double[] param) {
- if (debug) System.err.println("drotm");
- drotm(n, x, 0, incx, y, 0, incy, param, 0);
- }
-
- public void drotm(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam) {
- if (debug) System.err.println("drotm");
- requireNonNull(x);
- requireNonNull(y);
- requireNonNull(param);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- checkIndex(offsetparam + 4, param.length); /* param.length == 5 */
- drotmK(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
- }
-
- protected abstract void drotmK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam);
-
- public void srotm(int n, float[] x, int incx, float[] y, int incy, float[] param) {
- if (debug) System.err.println("srotm");
- srotm(n, x, 0, incx, y, 0, incy, param, 0);
- }
-
- public void srotm(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam) {
- if (debug) System.err.println("srotm");
- requireNonNull(x);
- requireNonNull(y);
- requireNonNull(param);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- checkIndex(offsetparam + 4, param.length); /* param.length == 5 */
- srotmK(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
- }
-
- protected abstract void srotmK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam);
-
- public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param) {
- if (debug) System.err.println("drotmg");
- drotmg(dd1, dd2, dx1, dy1, param, 0);
- }
-
- public void drotmg(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam) {
- if (debug) System.err.println("drotmg");
- requireNonNull(dd1);
- requireNonNull(dd2);
- requireNonNull(dx1);
- requireNonNull(param);
- checkIndex(offsetparam + 4, param.length);
- drotmgK(dd1, dd2, dx1, dy1, param, offsetparam);
- }
-
- protected abstract void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam);
-
- public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param) {
- if (debug) System.err.println("srotmg");
- srotmg(sd1, sd2, sx1, sy1, param, 0);
- }
-
- public void srotmg(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam) {
- if (debug) System.err.println("srotmg");
- requireNonNull(sd1);
- requireNonNull(sd2);
- requireNonNull(sx1);
- requireNonNull(param);
- checkIndex(offsetparam + 4, param.length);
- srotmgK(sd1, sd2, sx1, sy1, param, offsetparam);
- }
-
- protected abstract void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam);
-
- public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) {
- if (debug) System.err.println("dsbmv");
- dsbmv(uplo, n, k, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
- }
-
- public void dsbmv(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- if (debug) System.err.println("dsbmv");
- requireNonNull(a);
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- dsbmvK(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected abstract void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) {
- if (debug) System.err.println("ssbmv");
- ssbmv(uplo, n, k, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
- }
-
- public void ssbmv(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- if (debug) System.err.println("ssbmv");
- requireNonNull(a);
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- ssbmvK(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected abstract void ssbmvK(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- public void dscal(int n, double alpha, double[] x, int incx) {
- if (debug) System.err.println("dscal");
- dscal(n, alpha, x, 0, incx);
- }
-
- // x = alpha * x
- public void dscal(int n, double alpha, double[] x, int offsetx, int incx) {
- if (debug) System.err.println("dscal");
- if (n <= 0) {
- return;
- }
- if (incx <= 0) {
- return;
- }
- if (alpha == 1.0) {
- return;
- }
- requireNonNull(x);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- dscalK(n, alpha, x, offsetx, incx);
- }
-
- protected abstract void dscalK(int n, double alpha, double[] x, int offsetx, int incx);
-
- public void sscal(int n, float alpha, float[] x, int incx) {
- if (debug) System.err.println("sscal");
- sscal(n, alpha, x, 0, incx);
- }
-
- // x = alpha * x
- public void sscal(int n, float alpha, float[] x, int offsetx, int incx) {
- if (debug) System.err.println("sscal");
- if (n <= 0) {
- return;
- }
- if (incx <= 0) {
- return;
- }
- if (alpha == 1.0f) {
- return;
- }
- requireNonNull(x);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- sscalK(n, alpha, x, offsetx, incx);
- }
-
- protected abstract void sscalK(int n, float alpha, float[] x, int offsetx, int incx);
-
- public void dspmv(String uplo, int n, double alpha, double[] a, double[] x, int incx, double beta, double[] y, int incy) {
- if (debug) System.err.println("dspmv");
- dspmv(uplo, n, alpha, a, 0, x, 0, incx, beta, y, 0, incy);
- }
-
- // y = alpha * a * x + beta * y
- public void dspmv(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- if (debug) System.err.println("dspmv");
- checkArgument("DSPMV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DSPMV", 2, n >= 0);
- checkArgument("DSPMV", 6, incx != 0);
- checkArgument("DSPMV", 9, incy != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- dspmvK(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected abstract void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- public void sspmv(String uplo, int n, float alpha, float[] a, float[] x, int incx, float beta, float[] y, int incy) {
- if (debug) System.err.println("sspmv");
- sspmv(uplo, n, alpha, a, 0, x, 0, incx, beta, y, 0, incy);
- }
-
- public void sspmv(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- if (debug) System.err.println("sspmv");
- checkArgument("SSPMV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("SSPMV", 2, n >= 0);
- checkArgument("SSPMV", 6, incx != 0);
- checkArgument("SSPMV", 9, incy != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- sspmvK(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected abstract void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- public void dspr(String uplo, int n, double alpha, double[] x, int incx, double[] a) {
- if (debug) System.err.println("dspr");
- dspr(uplo, n, alpha, x, 0, incx, a, 0);
- }
-
- // a += alpha * x * x.t
- public void dspr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta) {
- if (debug) System.err.println("dspr");
- checkArgument("DSPR", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DSPR", 2, n >= 0);
- checkArgument("DSPR", 5, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(a);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
- dsprK(uplo, n, alpha, x, offsetx, incx, a, offseta);
- }
-
- protected abstract void dsprK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta);
-
- public void sspr(String uplo, int n, float alpha, float[] x, int incx, float[] a) {
- if (debug) System.err.println("sspr");
- sspr(uplo, n, alpha, x, 0, incx, a, 0);
- }
-
- // a += alpha * x * x.t
- public void sspr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta) {
- if (debug) System.err.println("sspr");
- checkArgument("SSPR", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("SSPR", 2, n >= 0);
- checkArgument("SSPR", 5, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(a);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
- ssprK(uplo, n, alpha, x, offsetx, incx, a, offseta);
- }
-
- protected abstract void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta);
-
- public void dspr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a) {
- if (debug) System.err.println("dspr2");
- dspr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0);
- }
-
- // a += alpha * x * y.t + alpha * y * x.t
- public void dspr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta) {
- if (debug) System.err.println("dspr2");
- checkArgument("DSPR2", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DSPR2", 2, n >= 0);
- checkArgument("DSPR2", 5, incx != 0);
- checkArgument("DSPR2", 7, incy != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- requireNonNull(a);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
- dspr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
- }
-
- protected abstract void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta);
-
- public void sspr2(String uplo, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a) {
- if (debug) System.err.println("sspr2");
- sspr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0);
- }
-
- // a += alpha * x * y.t + alpha * y * x.t
- public void sspr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta) {
- if (debug) System.err.println("sspr2");
- checkArgument("SSPR2", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("SSPR2", 2, n >= 0);
- checkArgument("SSPR2", 5, incx != 0);
- checkArgument("SSPR2", 7, incy != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- requireNonNull(a);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- checkIndex(offseta + (n * (n + 1) / 2) - 1, a.length);
- sspr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
- }
-
- protected abstract void sspr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta);
-
- public void dswap(int n, double[] x, int incx, double[] y, int incy) {
- if (debug) System.err.println("dswap");
- dswap(n, x, 0, incx, y, 0, incy);
- }
-
- public void dswap(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- if (debug) System.err.println("dswap");
- if (n <= 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- dswapK(n, x, offsetx, incx, y, offsety, incy);
- }
-
- protected abstract void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- public void sswap(int n, float[] x, int incx, float[] y, int incy) {
- if (debug) System.err.println("sswap");
- sswap(n, x, 0, incx, y, 0, incy);
- }
-
- public void sswap(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- if (debug) System.err.println("sswap");
- if (n <= 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- sswapK(n, x, offsetx, incx, y, offsety, incy);
- }
-
- protected abstract void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc) {
- if (debug) System.err.println("dsymm");
- dsymm(side, uplo, m, n, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
- }
-
- public void dsymm(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- if (debug) System.err.println("dsymm");
- checkArgument("DSYMM", 1, lsame("L", side) || lsame("R", side));
- checkArgument("DSYMM", 2, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DSYMM", 3, m >= 0);
- checkArgument("DSYMM", 4, n >= 0);
- checkArgument("DSYMM", 7, lda >= Math.max(1, lsame("L", side) ? m : n));
- checkArgument("DSYMM", 9, ldb >= Math.max(1, m));
- checkArgument("DSYMM", 12, ldc >= Math.max(1, m));
- if (m == 0 || n == 0 || (alpha == 0.0 && beta == 1.0)) {
- return;
- }
- requireNonNull(a);
- requireNonNull(b);
- requireNonNull(c);
- checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
- checkIndex(offsetb + n * ldb - 1, b.length);
- checkIndex(offsetc + n * ldc - 1, c.length);
- dsymmK(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected abstract void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
-
- public void ssymm(String side, String uplo, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int ldc) {
- if (debug) System.err.println("ssymm");
- ssymm(side, uplo, m, n, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
- }
-
- public void ssymm(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- if (debug) System.err.println("ssymm");
- checkArgument("SSYMM", 1, lsame("L", side) || lsame("R", side));
- checkArgument("SSYMM", 2, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("SSYMM", 3, m >= 0);
- checkArgument("SSYMM", 4, n >= 0);
- checkArgument("SSYMM", 7, lda >= Math.max(1, lsame("L", side) ? m : n));
- checkArgument("SSYMM", 9, ldb >= Math.max(1, m));
- checkArgument("SSYMM", 12, ldc >= Math.max(1, m));
- if (m == 0 || n == 0 || (alpha == 0.0f && beta == 1.0f)) {
- return;
- }
- requireNonNull(a);
- requireNonNull(b);
- requireNonNull(c);
- checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
- checkIndex(offsetb + n * ldb - 1, b.length);
- checkIndex(offsetc + n * ldc - 1, c.length);
- ssymmK(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected abstract void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
-
- public void dsymv(String uplo, int n, double alpha, double[] a, int lda, double[] x, int incx, double beta, double[] y, int incy) {
- if (debug) System.err.println("dsymv");
- dsymv(uplo, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
- }
-
- public void dsymv(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- if (debug) System.err.println("dsymv");
- checkArgument("DSYMV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DSYMV", 2, n >= 0);
- checkArgument("DSYMV", 5, lda >= Math.max(1, n));
- checkArgument("DSYMV", 7, incx != 0);
- checkArgument("DSYMV", 10, incy != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- dsymvK(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected abstract void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- public void ssymv(String uplo, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy) {
- if (debug) System.err.println("ssymv");
- ssymv(uplo, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy);
- }
-
- public void ssymv(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- if (debug) System.err.println("ssymv");
- checkArgument("SSYMV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("SSYMV", 2, n >= 0);
- checkArgument("SSYMV", 5, lda >= Math.max(1, n));
- checkArgument("SSYMV", 7, incx != 0);
- checkArgument("SSYMV", 10, incy != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- requireNonNull(y);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- ssymvK(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected abstract void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- public void dsyr(String uplo, int n, double alpha, double[] x, int incx, double[] a, int lda) {
- if (debug) System.err.println("dsyr");
- dsyr(uplo, n, alpha, x, 0, incx, a, 0, lda);
- }
-
- // a += alpha * x * x.t
- public void dsyr(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda) {
- if (debug) System.err.println("dsyr");
- checkArgument("DSYR", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DSYR", 2, n >= 0);
- checkArgument("DSYR", 5, incx != 0);
- checkArgument("DSYR", 7, lda >= Math.max(1, n));
- if (n == 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(a);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offseta + n * lda - 1, a.length);
- dsyrK(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
- }
-
- protected abstract void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda);
-
- public void ssyr(String uplo, int n, float alpha, float[] x, int incx, float[] a, int lda) {
- if (debug) System.err.println("ssyr");
- ssyr(uplo, n, alpha, x, 0, incx, a, 0, lda);
- }
-
- public void ssyr(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda) {
- if (debug) System.err.println("ssyr");
- checkArgument("SSYR", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("SSYR", 2, n >= 0);
- checkArgument("SSYR", 5, incx != 0);
- checkArgument("SSYR", 7, lda >= Math.max(1, n));
- if (n == 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(a);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offseta + n * lda - 1, a.length);
- ssyrK(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
- }
-
- protected abstract void ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda);
-
- public void dsyr2(String uplo, int n, double alpha, double[] x, int incx, double[] y, int incy, double[] a, int lda) {
- if (debug) System.err.println("dsyr2");
- dsyr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda);
- }
-
- public void dsyr2(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
- if (debug) System.err.println("dsyr2");
- checkArgument("DSYR2", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DSYR2", 2, n >= 0);
- checkArgument("DSYR2", 5, incx != 0);
- checkArgument("DSYR2", 7, incy != 0);
- checkArgument("DSYR2", 9, lda >= Math.max(1, n));
- if (n == 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- requireNonNull(a);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- checkIndex(offseta + n * lda - 1, a.length);
- dsyr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
- }
-
- protected abstract void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
-
- public void ssyr2(String uplo, int n, float alpha, float[] x, int incx, float[] y, int incy, float[] a, int lda) {
- if (debug) System.err.println("ssyr2");
- ssyr2(uplo, n, alpha, x, 0, incx, y, 0, incy, a, 0, lda);
- }
-
- public void ssyr2(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
- if (debug) System.err.println("ssyr2");
- checkArgument("SSYR2", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("SSYR2", 2, n >= 0);
- checkArgument("SSYR2", 5, incx != 0);
- checkArgument("SSYR2", 7, incy != 0);
- checkArgument("SSYR2", 9, lda >= Math.max(1, n));
- if (n == 0) {
- return;
- }
- requireNonNull(x);
- requireNonNull(y);
- requireNonNull(a);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- checkIndex(offsety + (n - 1) * Math.abs(incy), y.length);
- checkIndex(offseta + n * lda - 1, a.length);
- ssyr2K(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
- }
-
- protected abstract void ssyr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
-
- public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double[] b, int ldb, double beta, double[] c, int ldc) {
- if (debug) System.err.println("dsyr2k");
- dsyr2k(uplo, trans, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
- }
-
- public void dsyr2k(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- if (debug) System.err.println("dsyr2k");
- checkArgument("DSYR2K", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DSYR2K", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("DSYR2K", 3, n >= 0);
- checkArgument("DSYR2K", 4, k >= 0);
- checkArgument("DSYR2K", 7, lda >= Math.max(1, lsame("N", trans) ? n : k));
- checkArgument("DSYR2K", 9, ldb >= Math.max(1, lsame("N", trans) ? n : k));
- checkArgument("DSYR2K", 12, ldc >= Math.max(1, n));
- if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0))
- return;
- requireNonNull(a);
- requireNonNull(b);
- requireNonNull(c);
- checkIndex(offseta + (lsame("N", trans) ? k : n) * lda - 1, a.length);
- checkIndex(offsetb + (lsame("N", trans) ? k : n) * ldb - 1, b.length);
- checkIndex(offsetc + n * ldc - 1, c.length);
- dsyr2kK(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected abstract void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
-
- public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float[] b, int ldb, float beta, float[] c, int ldc) {
- if (debug) System.err.println("ssyr2k");
- ssyr2k(uplo, trans, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc);
- }
-
- public void ssyr2k(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- if (debug) System.err.println("ssyr2k");
- checkArgument("SSYR2K", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("SSYR2K", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("SSYR2K", 3, n >= 0);
- checkArgument("SSYR2K", 4, k >= 0);
- checkArgument("SSYR2K", 7, lda >= Math.max(1, lsame("N", trans) ? n : k));
- checkArgument("SSYR2K", 9, ldb >= Math.max(1, lsame("N", trans) ? n : k));
- checkArgument("SSYR2K", 12, ldc >= Math.max(1, n));
- if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0f))
- return;
- requireNonNull(a);
- requireNonNull(b);
- requireNonNull(c);
- checkIndex(offseta + (lsame("N", trans) ? k : n) * lda - 1, a.length);
- checkIndex(offsetb + (lsame("N", trans) ? k : n) * ldb - 1, b.length);
- checkIndex(offsetc + n * ldc - 1, c.length);
- ssyr2kK(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected abstract void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
-
- public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int lda, double beta, double[] c, int ldc) {
- if (debug) System.err.println("dsyrk");
- dsyrk(uplo, trans, n, k, alpha, a, 0, lda, beta, c, 0, ldc);
- }
-
- public void dsyrk(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc) {
- if (debug) System.err.println("dsyrk");
- checkArgument("DSYRK", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DSYRK", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("DSYRK", 3, n >= 0);
- checkArgument("DSYRK", 4, k >= 0);
- checkArgument("DSYRK", 7, lda >= Math.max(1, lsame("N", trans) ? n : k));
- checkArgument("DSYRK", 10, ldc >= Math.max(1, n));
- if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0))
- return;
- requireNonNull(a);
- requireNonNull(c);
- checkIndex(offseta + (lsame("N", trans) ? k : n) * lda - 1, a.length);
- checkIndex(offsetc + n * ldc - 1, c.length);
- dsyrkK(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
- }
-
- protected abstract void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc);
-
- public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int lda, float beta, float[] c, int ldc) {
- if (debug) System.err.println("ssyrk");
- ssyrk(uplo, trans, n, k, alpha, a, 0, lda, beta, c, 0, ldc);
- }
-
- public void ssyrk(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc) {
- if (debug) System.err.println("ssyrk");
- checkArgument("SSYRK", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("SSYRK", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("SSYRK", 3, n >= 0);
- checkArgument("SSYRK", 4, k >= 0);
- checkArgument("SSYRK", 7, lda >= Math.max(1, lsame("N", trans) ? n : k));
- checkArgument("SSYRK", 10, ldc >= Math.max(1, n));
- if (n == 0 || ((alpha == 0 || k == 0) && beta == 1.0f))
- return;
- requireNonNull(a);
- requireNonNull(c);
- checkIndex(offseta + (lsame("N", trans) ? k : n) * lda - 1, a.length);
- checkIndex(offsetc + n * ldc - 1, c.length);
- ssyrkK(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
- }
-
- protected abstract void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc);
-
- public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx) {
- if (debug) System.err.println("dtbmv");
- dtbmv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx);
- }
-
- public void dtbmv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- if (debug) System.err.println("dtbmv");
- checkArgument("DTBMV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DTBMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("DTBMV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("DTBMV", 4, n >= 0);
- checkArgument("DTBMV", 5, k >= 0);
- checkArgument("DTBMV", 7, lda >= Math.max(1, k));
- checkArgument("DTBMV", 9, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- dtbmvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
-
- protected abstract void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int lda, float[] x, int incx) {
- if (debug) System.err.println("stbmv");
- stbmv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx);
- }
-
- public void stbmv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- if (debug) System.err.println("stbmv");
- checkArgument("STBMV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("STBMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("STBMV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("STBMV", 4, n >= 0);
- checkArgument("STBMV", 5, k >= 0);
- checkArgument("STBMV", 7, lda >= Math.max(1, k));
- checkArgument("STBMV", 9, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- stbmvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
-
- protected abstract void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int lda, double[] x, int incx) {
- if (debug) System.err.println("dtbsv");
- dtbsv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx);
- }
-
- public void dtbsv(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- if (debug) System.err.println("dtbsv");
- checkArgument("DTBSV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DTBSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("DTBSV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("DTBSV", 4, n >= 0);
- checkArgument("DTBSV", 5, k >= 0);
- checkArgument("DTBSV", 7, lda >= Math.max(1, k));
- checkArgument("DTBSV", 9, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- dtbsvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
-
- protected abstract void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- public void stbsv(String uplo, String trans, String diag, int n, int k, float[] a, int lda, float[] x, int incx) {
- if (debug) System.err.println("stbsv");
- stbsv(uplo, trans, diag, n, k, a, 0, lda, x, 0, incx);
- }
-
- public void stbsv(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- if (debug) System.err.println("stbsv");
- checkArgument("STBSV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("STBSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("STBSV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("STBSV", 4, n >= 0);
- checkArgument("STBSV", 5, k >= 0);
- checkArgument("STBSV", 7, lda >= Math.max(1, k));
- checkArgument("STBSV", 9, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- stbsvK(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
-
- protected abstract void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- public void dtpmv(String uplo, String trans, String diag, int n, double[] a, double[] x, int incx) {
- if (debug) System.err.println("dtpmv");
- dtpmv(uplo, trans, diag, n, a, 0, x, 0, incx);
- }
-
- public void dtpmv(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
- if (debug) System.err.println("dtpmv");
- checkArgument("DTPMV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DTPMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("DTPMV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("DTPMV", 4, n >= 0);
- checkArgument("DTPMV", 7, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * (n + 1) / 2 - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- dtpmvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
-
- protected abstract void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx);
-
- public void stpmv(String uplo, String trans, String diag, int n, float[] a, float[] x, int incx) {
- if (debug) System.err.println("stpmv");
- stpmv(uplo, trans, diag, n, a, 0, x, 0, incx);
- }
-
- public void stpmv(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
- if (debug) System.err.println("stpmv");
- checkArgument("STPMV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("STPMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("STPMV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("STPMV", 4, n >= 0);
- checkArgument("STPMV", 7, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * (n + 1) / 2 - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- stpmvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
-
- protected abstract void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx);
-
- public void dtpsv(String uplo, String trans, String diag, int n, double[] a, double[] x, int incx) {
- if (debug) System.err.println("dtpsv");
- dtpsv(uplo, trans, diag, n, a, 0, x, 0, incx);
- }
-
- public void dtpsv(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
- if (debug) System.err.println("dtpsv");
- checkArgument("DTPSV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DTPSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("DTPSV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("DTPSV", 4, n >= 0);
- checkArgument("DTPSV", 7, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * (n + 1) / 2 - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- dtpsvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
-
- protected abstract void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx);
-
- public void stpsv(String uplo, String trans, String diag, int n, float[] a, float[] x, int incx) {
- if (debug) System.err.println("stpsv");
- stpsv(uplo, trans, diag, n, a, 0, x, 0, incx);
- }
-
- public void stpsv(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
- if (debug) System.err.println("stpsv");
- checkArgument("STPSV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("STPSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("STPSV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("STPSV", 4, n >= 0);
- checkArgument("STPSV", 7, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * (n + 1) / 2 - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- stpsvK(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
-
- protected abstract void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx);
-
- public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb) {
- if (debug) System.err.println("dtrmm");
- dtrmm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb);
- }
-
- public void dtrmm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
- if (debug) System.err.println("dtrmm");
- checkArgument("DTRMM", 1, lsame("L", side) || lsame("R", side));
- checkArgument("DTRMM", 2, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DTRMM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa));
- checkArgument("DTRMM", 4, lsame("U", diag) || lsame("N", diag));
- checkArgument("DTRMM", 5, m >= 0);
- checkArgument("DTRMM", 6, n >= 0);
- checkArgument("DTRMM", 9, lda >= Math.max(1, lsame("L", side) ? m : n));
- checkArgument("DTRMM", 11, ldb >= Math.max(1, m));
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(b);
- checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
- checkIndex(offsetb + n * ldb - 1, b.length);
- dtrmmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
-
- protected abstract void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
-
- public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb) {
- if (debug) System.err.println("strmm");
- strmm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb);
- }
-
- public void strmm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
- if (debug) System.err.println("strmm");
- checkArgument("STRMM", 1, lsame("L", side) || lsame("R", side));
- checkArgument("STRMM", 2, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("STRMM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa));
- checkArgument("STRMM", 4, lsame("U", diag) || lsame("N", diag));
- checkArgument("STRMM", 5, m >= 0);
- checkArgument("STRMM", 6, n >= 0);
- checkArgument("STRMM", 9, lda >= Math.max(1, lsame("L", side) ? m : n));
- checkArgument("STRMM", 11, ldb >= Math.max(1, m));
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(b);
- checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
- checkIndex(offsetb + n * ldb - 1, b.length);
- strmmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
-
- protected abstract void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
-
- public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx) {
- if (debug) System.err.println("dtrmv");
- dtrmv(uplo, trans, diag, n, a, 0, lda, x, 0, incx);
- }
-
- public void dtrmv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- if (debug) System.err.println("dtrmv");
- checkArgument("DTRMV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DTRMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("DTRMV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("DTRMV", 4, n >= 0);
- checkArgument("DTRMV", 6, lda >= Math.max(1, n));
- checkArgument("DTRMV", 8, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- dtrmvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
-
- protected abstract void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- public void strmv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx) {
- if (debug) System.err.println("strmv");
- strmv(uplo, trans, diag, n, a, 0, lda, x, 0, incx);
- }
-
- public void strmv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- if (debug) System.err.println("strmv");
- checkArgument("STRMV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("STRMV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("STRMV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("STRMV", 4, n >= 0);
- checkArgument("STRMV", 6, lda >= Math.max(1, n));
- checkArgument("STRMV", 8, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- strmvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
-
- protected abstract void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int lda, double[] b, int ldb) {
- if (debug) System.err.println("dtrsm");
- dtrsm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb);
- }
-
- public void dtrsm(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
- if (debug) System.err.println("dtrsm");
- checkArgument("DTRSM", 1, lsame("L", side) || lsame("R", side));
- checkArgument("DTRSM", 2, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DTRSM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa));
- checkArgument("DTRSM", 4, lsame("U", diag) || lsame("N", diag));
- checkArgument("DTRSM", 5, m >= 0);
- checkArgument("DTRSM", 6, n >= 0);
- checkArgument("DTRSM", 9, lda >= Math.max(1, lsame("L", side) ? m : n));
- checkArgument("DTRSM", 11, ldb >= Math.max(1, m));
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(b);
- checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
- checkIndex(offsetb + n * ldb - 1, b.length);
- dtrsmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
-
- protected abstract void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
-
- public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int lda, float[] b, int ldb) {
- if (debug) System.err.println("strsm");
- strsm(side, uplo, transa, diag, m, n, alpha, a, 0, lda, b, 0, ldb);
- }
-
- public void strsm(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
- if (debug) System.err.println("strsm");
- checkArgument("STRSM", 1, lsame("L", side) || lsame("R", side));
- checkArgument("STRSM", 2, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("STRSM", 3, lsame("N", transa) || lsame("T", transa) || lsame("C", transa));
- checkArgument("STRSM", 4, lsame("U", diag) || lsame("N", diag));
- checkArgument("STRSM", 5, m >= 0);
- checkArgument("STRSM", 6, n >= 0);
- checkArgument("STRSM", 9, lda >= Math.max(1, lsame("L", side) ? m : n));
- checkArgument("STRSM", 11, ldb >= Math.max(1, m));
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(b);
- checkIndex(offseta + (lsame("L", side) ? m : n) * lda - 1, a.length);
- checkIndex(offsetb + n * ldb - 1, b.length);
- strsmK(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
-
- protected abstract void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
-
- public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int lda, double[] x, int incx) {
- if (debug) System.err.println("dtrsv");
- dtrsv(uplo, trans, diag, n, a, 0, lda, x, 0, incx);
- }
-
- public void dtrsv(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- if (debug) System.err.println("dtrsv");
- checkArgument("DTRSV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("DTRSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("DTRSV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("DTRSV", 4, n >= 0);
- checkArgument("DTRSV", 6, lda >= Math.max(1, n));
- checkArgument("DTRSV", 8, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- dtrsvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
-
- protected abstract void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- public void strsv(String uplo, String trans, String diag, int n, float[] a, int lda, float[] x, int incx) {
- if (debug) System.err.println("strsv");
- strsv(uplo, trans, diag, n, a, 0, lda, x, 0, incx);
- }
-
- public void strsv(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- if (debug) System.err.println("strsv");
- checkArgument("STRSV", 1, lsame("U", uplo) || lsame("L", uplo));
- checkArgument("STRSV", 2, lsame("N", trans) || lsame("T", trans) || lsame("C", trans));
- checkArgument("STRSV", 3, lsame("U", diag) || lsame("N", diag));
- checkArgument("STRSV", 4, n >= 0);
- checkArgument("STRSV", 6, lda >= Math.max(1, n));
- checkArgument("STRSV", 8, incx != 0);
- if (n == 0) {
- return;
- }
- requireNonNull(a);
- requireNonNull(x);
- checkIndex(offseta + n * lda - 1, a.length);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- strsvK(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
-
- protected abstract void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- public int idamax(int n, double[] x, int incx) {
- if (debug) System.err.println("idamax");
- return idamax(n, x, 0, incx);
- }
-
- public int idamax(int n, double[] x, int offsetx, int incx) {
- if (debug) System.err.println("idamax");
- if (n <= 0) {
- return -1;
- }
- if (incx <= 0) {
- return -1;
- }
- if (n == 1) {
- return 0;
- }
- requireNonNull(x);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- // Fortran arrays use 1-based index
- return idamaxK(n, x, offsetx, incx) - 1;
- }
-
- protected abstract int idamaxK(int n, double[] x, int offsetx, int incx);
-
- public int isamax(int n, float[] x, int incx) {
- if (debug) System.err.println("isamax");
- return isamax(n, x, 0, incx);
- }
-
- public int isamax(int n, float[] x, int offsetx, int incx) {
- if (debug) System.err.println("isamax");
- if (n <= 0) {
- return -1;
- }
- if (incx <= 0) {
- return -1;
- }
- if (n == 1) {
- return 0;
- }
- requireNonNull(x);
- checkIndex(offsetx + (n - 1) * Math.abs(incx), x.length);
- // Fortran arrays use 1-based index
- return isamaxK(n, x, offsetx, incx) - 1;
- }
-
- protected abstract int isamaxK(int n, float[] x, int offsetx, int incx);
-
- public boolean lsame(String ca, String cb) {
- if (debug) System.err.println("lsame");
- return ca != null && ca.regionMatches(true, 0, cb, 0, ca.length());
- }
-}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/blas/F2jBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/F2jBLAS.java
deleted file mode 100644
index d0a8dab6a7b62edb52772167a7cd272bb7f076be..0000000000000000000000000000000000000000
--- a/ml-core/src/main/java/dev/ludovic/netlib/blas/F2jBLAS.java
+++ /dev/null
@@ -1,241 +0,0 @@
-/*
- * Copyright 2020, 2021, Ludovic Henry
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- *
- * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
- * information or have any questions.
- */
-
-package dev.ludovic.netlib.blas;
-
-import dev.ludovic.netlib.BLAS;
-
-public final class F2jBLAS extends AbstractBLAS implements dev.ludovic.netlib.JavaBLAS {
-
- private static final F2jBLAS instance = new F2jBLAS();
-
- protected F2jBLAS() {}
-
- public static dev.ludovic.netlib.JavaBLAS getInstance() {
- return instance;
- }
-
- protected double dasumK(int n, double[] x, int offsetx, int incx) {
- return org.netlib.blas.Dasum.dasum(n, x, offsetx, incx);
- }
- protected float sasumK(int n, float[] x, int offsetx, int incx) {
- return org.netlib.blas.Sasum.sasum(n, x, offsetx, incx);
- }
- protected void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- org.netlib.blas.Daxpy.daxpy(n, alpha, x, offsetx, incx, y, offsety, incy);
- }
- protected void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- org.netlib.blas.Saxpy.saxpy(n, alpha, x, offsetx, incx, y, offsety, incy);
- }
- protected void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- org.netlib.blas.Dcopy.dcopy(n, x, offsetx, incx, y, offsety, incy);
- }
- protected void scopyK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- org.netlib.blas.Scopy.scopy(n, x, offsetx, incx, y, offsety, incy);
- }
- protected double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- return org.netlib.blas.Ddot.ddot(n, x, offsetx, incx, y, offsety, incy);
- }
- protected float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- return org.netlib.blas.Sdot.sdot(n, x, offsetx, incx, y, offsety, incy);
- }
- protected float sdsdotK(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- return org.netlib.blas.Sdsdot.sdsdot(n, sb, x, offsetx, incx, y, offsety, incy);
- }
- protected void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- org.netlib.blas.Dgbmv.dgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- protected void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- org.netlib.blas.Sgbmv.sgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- protected void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- org.netlib.blas.Dgemm.dgemm(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- protected void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- org.netlib.blas.Sgemm.sgemm(transa, transb, m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- protected void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- org.netlib.blas.Dgemv.dgemv(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- protected void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- org.netlib.blas.Sgemv.sgemv(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- protected void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
- org.netlib.blas.Dger.dger(m, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
- }
- protected void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
- org.netlib.blas.Sger.sger(m, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
- }
- protected double dnrm2K(int n, double[] x, int offsetx, int incx) {
- return org.netlib.blas.Dnrm2.dnrm2(n, x, offsetx, incx);
- }
- protected float snrm2K(int n, float[] x, int offsetx, int incx) {
- return org.netlib.blas.Snrm2.snrm2(n, x, offsetx, incx);
- }
- protected void drotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s) {
- org.netlib.blas.Drot.drot(n, x, offsetx, incx, y, offsety, incy, c, s);
- }
- protected void srotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s) {
- org.netlib.blas.Srot.srot(n, x, offsetx, incx, y, offsety, incy, c, s);
- }
- protected void drotmK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam) {
- org.netlib.blas.Drotm.drotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
- }
- protected void srotmK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam) {
- org.netlib.blas.Srotm.srotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
- }
- protected void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam) {
- org.netlib.blas.Drotmg.drotmg(dd1, dd2, dx1, dy1, param, offsetparam);
- }
- protected void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam) {
- org.netlib.blas.Srotmg.srotmg(sd1, sd2, sx1, sy1, param, offsetparam);
- }
- protected void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- org.netlib.blas.Dsbmv.dsbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- protected void ssbmvK(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- org.netlib.blas.Ssbmv.ssbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- protected void dscalK(int n, double alpha, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dscal.dscal(n, alpha, x, offsetx, incx);
- }
- protected void sscalK(int n, float alpha, float[] x, int offsetx, int incx) {
- org.netlib.blas.Sscal.sscal(n, alpha, x, offsetx, incx);
- }
- protected void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- org.netlib.blas.Dspmv.dspmv(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
- }
- protected void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- org.netlib.blas.Sspmv.sspmv(uplo, n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
- }
- protected void dsprK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta) {
- org.netlib.blas.Dspr.dspr(uplo, n, alpha, x, offsetx, incx, a, offseta);
- }
- protected void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta) {
- org.netlib.blas.Sspr.sspr(uplo, n, alpha, x, offsetx, incx, a, offseta);
- }
- protected void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta) {
- org.netlib.blas.Dspr2.dspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
- }
- protected void sspr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta) {
- org.netlib.blas.Sspr2.sspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
- }
- protected void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- org.netlib.blas.Dswap.dswap(n, x, offsetx, incx, y, offsety, incy);
- }
- protected void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- org.netlib.blas.Sswap.sswap(n, x, offsetx, incx, y, offsety, incy);
- }
- protected void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- org.netlib.blas.Dsymm.dsymm(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- protected void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- org.netlib.blas.Ssymm.ssymm(side, uplo, m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- protected void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- org.netlib.blas.Dsymv.dsymv(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- protected void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- org.netlib.blas.Ssymv.ssymv(uplo, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- protected void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda) {
- org.netlib.blas.Dsyr.dsyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
- }
- protected void ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda) {
- org.netlib.blas.Ssyr.ssyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
- }
- protected void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
- org.netlib.blas.Dsyr2.dsyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
- }
- protected void ssyr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
- org.netlib.blas.Ssyr2.ssyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
- }
- protected void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- org.netlib.blas.Dsyr2k.dsyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- protected void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- org.netlib.blas.Ssyr2k.ssyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- protected void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc) {
- org.netlib.blas.Dsyrk.dsyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
- }
- protected void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc) {
- org.netlib.blas.Ssyrk.ssyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
- }
- protected void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtbmv.dtbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
- protected void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- org.netlib.blas.Stbmv.stbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
- protected void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtbsv.dtbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
- protected void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- org.netlib.blas.Stbsv.stbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
- protected void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtpmv.dtpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
- protected void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
- org.netlib.blas.Stpmv.stpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
- protected void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtpsv.dtpsv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
- protected void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
- org.netlib.blas.Stpsv.stpsv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
- protected void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
- org.netlib.blas.Dtrmm.dtrmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
- protected void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
- org.netlib.blas.Strmm.strmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
- protected void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtrmv.dtrmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
- protected void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- org.netlib.blas.Strmv.strmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
- protected void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
- org.netlib.blas.Dtrsm.dtrsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
- protected void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
- org.netlib.blas.Strsm.strsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
- protected void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtrsv.dtrsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
- protected void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- org.netlib.blas.Strsv.strsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
- protected int idamaxK(int n, double[] x, int offsetx, int incx) {
- return org.netlib.blas.Idamax.idamax(n, x, offsetx, incx);
- }
- protected int isamaxK(int n, float[] x, int offsetx, int incx) {
- return org.netlib.blas.Isamax.isamax(n, x, offsetx, incx);
- }
-}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/blas/JNIBLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/JNIBLAS.java
deleted file mode 100644
index 34378855fbe88de4c153ad9c1dfd7d2afb0e5805..0000000000000000000000000000000000000000
--- a/ml-core/src/main/java/dev/ludovic/netlib/blas/JNIBLAS.java
+++ /dev/null
@@ -1,201 +0,0 @@
-/*
- * Copyright 2020, 2021, Ludovic Henry
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- *
- * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
- * information or have any questions.
- */
-
-package dev.ludovic.netlib.blas;
-
-import java.io.InputStream;
-import java.io.IOException;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.StandardCopyOption;
-import java.nio.file.attribute.PosixFilePermissions;
-
-public final class JNIBLAS extends AbstractBLAS implements dev.ludovic.netlib.NativeBLAS {
-
- private static final JNIBLAS instance = new JNIBLAS();
-
- protected JNIBLAS() {
- String osName = System.getProperty("os.name");
- if (osName == null || osName.isEmpty()) {
- throw new RuntimeException("Unable to load native implementation");
- }
- String osArch = System.getProperty("os.arch");
- if (osArch == null || osArch.isEmpty()) {
- throw new RuntimeException("Unable to load native implementation");
- }
-
- Path temp;
- try (InputStream resource = this.getClass().getClassLoader().getResourceAsStream(
- String.format("resources/native/%s-%s/libnetlibblasjni.so", osName, osArch))) {
- assert resource != null;
- Files.copy(resource, temp = Files.createTempFile("libnetlibblasjni.so", "",
- PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rwxr-x---"))),
- StandardCopyOption.REPLACE_EXISTING);
- temp.toFile().deleteOnExit();
- } catch (IOException e) {
- throw new RuntimeException("Unable to load native implementation", e);
- }
-
- System.load(temp.toString());
- }
-
- public static dev.ludovic.netlib.NativeBLAS getInstance() {
- return instance;
- }
-
- protected native double dasumK(int n, double[] x, int offsetx, int incx);
-
- protected native float sasumK(int n, float[] x, int offsetx, int incx);
-
- protected native void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- protected native void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- protected native void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- protected native void scopyK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- protected native double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- protected native float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- protected native float sdsdotK(int n, float sb, float[] sx, int offsetsx, int incsx, float[] sy, int offsetsy, int incsy);
-
- protected native void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- protected native void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- protected native void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
-
- protected native void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
-
- protected native void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- protected native void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- protected native void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
-
- protected native void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
-
- protected native double dnrm2K(int n, double[] x, int offsetx, int incx);
-
- protected native float snrm2K(int n, float[] x, int offsetx, int incx);
-
- protected native void drotK(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double c, double s);
-
- protected native void srotK(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float c, float s);
-
- protected native void drotmK(int n, double[] dx, int offsetdx, int incx, double[] dy, int offsetdy, int incy, double[] dparam, int offsetdparam);
-
- protected native void srotmK(int n, float[] sx, int offsetsx, int incx, float[] sy, int offsetsy, int incy, float[] sparam, int offsetsparam);
-
- protected native void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] dparam, int offsetdparam);
-
- protected native void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] sparam, int offsetsparam);
-
- protected native void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- protected native void ssbmvK(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- protected native void dscalK(int n, double alpha, double[] x, int offsetx, int incx);
-
- protected native void sscalK(int n, float alpha, float[] x, int offsetx, int incx);
-
- protected native void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- protected native void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- protected native void dsprK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta);
-
- protected native void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta);
-
- protected native void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta);
-
- protected native void sspr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta);
-
- protected native void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy);
-
- protected native void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy);
-
- protected native void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
-
- protected native void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
-
- protected native void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy);
-
- protected native void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy);
-
- protected native void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda);
-
- protected native void ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda);
-
- protected native void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda);
-
- protected native void ssyr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda);
-
- protected native void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc);
-
- protected native void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc);
-
- protected native void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc);
-
- protected native void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc);
-
- protected native void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- protected native void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- protected native void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- protected native void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- protected native void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx);
-
- protected native void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx);
-
- protected native void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx);
-
- protected native void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx);
-
- protected native void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
-
- protected native void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
-
- protected native void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- protected native void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- protected native void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb);
-
- protected native void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb);
-
- protected native void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx);
-
- protected native void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx);
-
- protected native int idamaxK(int n, double[] dx, int offsetdx, int incdx);
-
- protected native int isamaxK(int n, float[] sx, int offsetsx, int incx);
-}
diff --git a/ml-core/src/main/java/dev/ludovic/netlib/blas/Java8BLAS.java b/ml-core/src/main/java/dev/ludovic/netlib/blas/Java8BLAS.java
deleted file mode 100644
index 443a6329f32a2cd7e70144b87328dd752230eaf9..0000000000000000000000000000000000000000
--- a/ml-core/src/main/java/dev/ludovic/netlib/blas/Java8BLAS.java
+++ /dev/null
@@ -1,5157 +0,0 @@
-/*
- * Copyright 2020, 2021, Ludovic Henry
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- *
- * Please contact git@ludovic.dev or visit ludovic.dev if you need additional
- * information or have any questions.
- */
-
-package dev.ludovic.netlib.blas;
-
-import dev.ludovic.netlib.BLAS;
-
-public class Java8BLAS extends AbstractBLAS implements dev.ludovic.netlib.JavaBLAS {
-
- private static final Java8BLAS instance = new Java8BLAS();
-
- protected Java8BLAS() {}
-
- public static dev.ludovic.netlib.JavaBLAS getInstance() {
- return instance;
- }
-
- protected double dasumK(int n, double[] x, int offsetx, int incx) {
- double sum = 0.0;
- if (incx == 1) {
- int ix = 0;
- double sum0 = 0.0;
- double sum1 = 0.0;
- double sum2 = 0.0;
- double sum3 = 0.0;
- for (; ix < loopBound(n, 4); ix += 4) {
- sum0 += Math.abs(x[offsetx + ix + 0]);
- sum1 += Math.abs(x[offsetx + ix + 1]);
- sum2 += Math.abs(x[offsetx + ix + 2]);
- sum3 += Math.abs(x[offsetx + ix + 3]);
- }
- sum += sum0 + sum1 + sum2 + sum3;
- for (; ix < n; ix += 1) {
- sum += Math.abs(x[offsetx + ix]);
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ? ix >= 0 : ix < n * incx; ix += incx) {
- sum += Math.abs(x[offsetx + ix]);
- }
- }
- return sum;
- }
-
- protected float sasumK(int n, float[] x, int offsetx, int incx) {
- float sum = 0.0f;
- if (incx == 1) {
- int ix = 0;
- float sum0 = 0.0f;
- float sum1 = 0.0f;
- float sum2 = 0.0f;
- float sum3 = 0.0f;
- for (; ix < loopBound(n, 4); ix += 4) {
- sum0 += Math.abs(x[offsetx + ix + 0]);
- sum1 += Math.abs(x[offsetx + ix + 1]);
- sum2 += Math.abs(x[offsetx + ix + 2]);
- sum3 += Math.abs(x[offsetx + ix + 3]);
- }
- sum += sum0 + sum1 + sum2 + sum3;
- for (; ix < n; ix += 1) {
- sum += Math.abs(x[offsetx + ix]);
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ? ix >= 0 : ix < n * incx; ix += incx) {
- sum += Math.abs(x[offsetx + ix]);
- }
- }
- return sum;
- }
-
- protected void daxpyK(int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- if (incx == 1 && incy == 1) {
- for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
- y[offsety + iy] += alpha * x[offsetx + ix];
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0,
- iy = incy < 0 ? (n - 1) * -incy : 0;
- (incx < 0 ? ix >= 0 : ix < n * incx)
- && (incy < 0 ? iy >= 0 : iy < n * incy);
- ix += incx, iy += incy) {
- y[offsety + iy] += alpha * x[offsetx + ix];
- }
- }
- }
-
- protected void saxpyK(int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- if (incx == 1 && incy == 1) {
- for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
- y[offsety + iy] += alpha * x[offsetx + ix];
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0,
- iy = incy < 0 ? (n - 1) * -incy : 0;
- (incx < 0 ? ix >= 0 : ix < n * incx)
- && (incy < 0 ? iy >= 0 : iy < n * incy);
- ix += incx, iy += incy) {
- y[offsety + iy] += alpha * x[offsetx + ix];
- }
- }
- }
-
- protected void dcopyK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- if (incx == 1 && incy == 1) {
- System.arraycopy(x, offsetx, y, offsety, n);
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0,
- iy = incy < 0 ? (n - 1) * -incy : 0;
- (incx < 0 ? ix >= 0 : ix < n * incx)
- && (incy < 0 ? iy >= 0 : iy < n * incy);
- ix += incx, iy += incy) {
- y[offsety + iy] = x[offsetx + ix];
- }
- }
- }
-
- protected void scopyK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- if (incx == 1 && incy == 1) {
- System.arraycopy(x, offsetx, y, offsety, n);
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0,
- iy = incy < 0 ? (n - 1) * -incy : 0;
- (incx < 0 ? ix >= 0 : ix < n * incx)
- && (incy < 0 ? iy >= 0 : iy < n * incy);
- ix += incx, iy += incy) {
- y[offsety + iy] = x[offsetx + ix];
- }
- }
- }
-
- protected double ddotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- double sum = 0.0;
- if (incx == 1 && incy == 1) {
- int ix = 0, iy = 0;
- double sum0 = 0.0;
- double sum1 = 0.0;
- double sum2 = 0.0;
- double sum3 = 0.0;
- for (; ix < loopBound(n, 4) && iy < loopBound(n, 4); ix += 4, iy += 4) {
- sum0 += x[offsetx + ix + 0] * y[offsety + iy + 0];
- sum1 += x[offsetx + ix + 1] * y[offsety + iy + 1];
- sum2 += x[offsetx + ix + 2] * y[offsety + iy + 2];
- sum3 += x[offsetx + ix + 3] * y[offsety + iy + 3];
- }
- sum += sum0 + sum1 + sum2 + sum3;
- for (; ix < n && iy < n; ix += 1, iy += 1) {
- sum += x[offsetx + ix] * y[offsety + iy];
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0,
- iy = incy < 0 ? (n - 1) * -incy : 0;
- (incx < 0 ? ix >= 0 : ix < n * incx)
- && (incy < 0 ? iy >= 0 : iy < n * incy);
- ix += incx, iy += incy) {
- sum += x[offsetx + ix] * y[offsety + iy];
- }
- }
- return sum;
- }
-
- protected float sdotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- float sum = 0.0f;
- if (incx == 1 && incy == 1) {
- int ix = 0, iy = 0;
- float sum0 = 0.0f;
- float sum1 = 0.0f;
- float sum2 = 0.0f;
- float sum3 = 0.0f;
- for (; ix < loopBound(n, 4) && iy < loopBound(n, 4); ix += 4, iy += 4) {
- sum0 += x[offsetx + ix + 0] * y[offsety + iy + 0];
- sum1 += x[offsetx + ix + 1] * y[offsety + iy + 1];
- sum2 += x[offsetx + ix + 2] * y[offsety + iy + 2];
- sum3 += x[offsetx + ix + 3] * y[offsety + iy + 3];
- }
- sum += sum0 + sum1 + sum2 + sum3;
- for (; ix < n && iy < n; ix += 1, iy += 1) {
- sum += x[offsetx + ix] * y[offsety + iy];
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0,
- iy = incy < 0 ? (n - 1) * -incy : 0;
- (incx < 0 ? ix >= 0 : ix < n * incx)
- && (incy < 0 ? iy >= 0 : iy < n * incy);
- ix += incx, iy += incy) {
- sum += x[offsetx + ix] * y[offsety + iy];
- }
- }
- return sum;
- }
-
- protected float sdsdotK(int n, float sb, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- double sum = sb;
- if (incx == 1 && incy == 1) {
- int ix = 0, iy = 0;
- double sum0 = 0.0;
- double sum1 = 0.0;
- double sum2 = 0.0;
- double sum3 = 0.0;
- for (; ix < loopBound(n, 4) && iy < loopBound(n, 4); ix += 4, iy += 4) {
- sum0 += (double)x[offsetx + ix + 0] * (double)y[offsety + iy + 0];
- sum1 += (double)x[offsetx + ix + 1] * (double)y[offsety + iy + 1];
- sum2 += (double)x[offsetx + ix + 2] * (double)y[offsety + iy + 2];
- sum3 += (double)x[offsetx + ix + 3] * (double)y[offsety + iy + 3];
- }
- sum += sum0 + sum1 + sum2 + sum3;
- for (; ix < n && iy < n; ix += 1, iy += 1) {
- sum += (double)(x[offsetx + ix]) * (double)(y[offsety + iy]);
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0,
- iy = incy < 0 ? (n - 1) * -incy : 0;
- (incx < 0 ? ix >= 0 : ix < n * incx)
- && (incy < 0 ? iy >= 0 : iy < n * incy);
- ix += incx, iy += incy) {
- sum += (double)(x[offsetx + ix]) * (double)(y[offsety + iy]);
- }
- }
- return (float)sum;
- }
-
- protected void dgbmvK(String trans, int m, int n, int kl, int ku, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- org.netlib.blas.Dgbmv.dgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- protected void sgbmvK(String trans, int m, int n, int kl, int ku, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- org.netlib.blas.Sgbmv.sgbmv(trans, m, n, kl, ku, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected void dgemmK(String transa, String transb, int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- if (alpha == 0.0) {
- dgemmBeta(0, m, 0, n, beta, c, offsetc, ldc);
- } else if (m * n * k < 100 * 100 * 100) {
- // The matrices are small and it's faster to do the non-copying version
- if (lsame("N", transa) && lsame("N", transb)) {
- dgemmNN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else if (lsame("N", transa)) {
- dgemmNT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else if (lsame("N", transb)) {
- dgemmTN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else {
- dgemmTT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- } else {
- final int Krow = (int)(Math.ceil((double)(Math.min(60, m)) / 3) * 3),
- Kcol = (int)(Math.ceil((double)(Math.min(1000, n)) / 3) * 3),
- Ki = (int)(Math.ceil((double)(Math.min(500, k)) / 4) * 4);
-
- assert Krow > 0;
- assert Kcol > 0;
- assert Ki > 0;
-
- double[] packeda = new double[Krow * Ki];
- double[] packedb = new double[Kcol * Ki];
- double[] packedc = new double[Kcol * Krow];
-
- // c = beta * c
- dgemmBeta(0, m, 0, n, beta, c, offsetc, ldc);
- // c += alpha * a * b
- for (int col = 0; col < n; col += Kcol) {
- int cols = col, cole = Math.min(col + Kcol, n);
- for (int i = 0; i < k; i += Ki) {
- int is = i, ie = Math.min(i + Ki, k);
- // pack b
- if (lsame("N", transb)) {
- dgecpyNN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0);
- } else {
- dgecpyTN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0);
- }
- // GEPP
- for (int row = 0; row < m; row += Krow) {
- int rows = row, rowe = Math.min(row + Krow, m);
- // pack A
- if (lsame("N", transa)) {
- dgecpyNT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0);
- } else {
- dgecpyTT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0);
- }
- // pack C
- dgecpyNN(rowe - rows, cole - cols, c, offsetc, ldc, rows, cols, packedc, 0, Krow, 0, 0);
- // GEBP
- dgebpTN(Krow, 0, rowe - rows, Kcol, 0, cole - cols, Ki, 0, ie - is,
- alpha, packeda, 0, Ki, packedb, 0, Ki, beta, packedc, 0, Krow);
- // unpack C
- dgecpyNN(rowe - rows, cole - cols, packedc, 0, Krow, 0, 0, c, offsetc, ldc, rows, cols);
- }
- }
- }
- }
- }
-
- protected void dgemmBeta(int rows, int rowe, int cols, int cole, double beta, double[] c, int offsetc, int ldc) {
- if (beta != 1.0) {
- int col = cols;
- for (; col < loopAlign(cols, cole, 4); col += 1) {
- int row = rows;
- for (; row < rowe; row += 1) {
- if (beta != 0.0) {
- c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
- } else {
- c[offsetc + row + (col + 0) * ldc] = 0.0;
- }
- }
- }
- for (; col < loopBound(cole, 4); col += 4) {
- int row = rows;
- for (; row < rowe; row += 1) {
- if (beta != 0.0) {
- c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
- c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc];
- c[offsetc + row + (col + 2) * ldc] = beta * c[offsetc + row + (col + 2) * ldc];
- c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc];
- } else {
- c[offsetc + row + (col + 0) * ldc] = 0.0;
- c[offsetc + row + (col + 1) * ldc] = 0.0;
- c[offsetc + row + (col + 2) * ldc] = 0.0;
- c[offsetc + row + (col + 3) * ldc] = 0.0;
- }
- }
- }
- for (; col < cole; col += 1) {
- int row = rows;
- for (; row < rowe; row += 1) {
- if (beta != 0.0) {
- c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
- } else {
- c[offsetc + row + (col + 0) * ldc] = 0.0;
- }
- }
- }
- }
- }
-
- protected void dgecpyNN(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
- int col = 0;
- for (; col < loopBound(n, 4); col += 4) {
- System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m);
- System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 1) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 1) * lddst, m);
- System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 2) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 2) * lddst, m);
- System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 3) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 3) * lddst, m);
- }
- for (; col < n; col += 1) {
- System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m);
- }
- }
-
- protected void dgecpyNT(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
- int col = 0;
- for (; col < loopBound(n, 3); col += 3) {
- int row = 0;
- for (; row < loopBound(m, 3); row += 3) {
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc];
- dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 1) * ldsrc];
- dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 1) * ldsrc];
- dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc];
- dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 2) * ldsrc];
- dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 2) * ldsrc];
- }
- for (; row < m; row += 1) {
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc];
- dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc];
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, 3); row += 3) {
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 0) * ldsrc];
- }
- for (; row < m; row += 1) {
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
- }
- }
- }
-
- protected void dgecpyTN(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
- int row = 0;
- for (; row < loopBound(m, 3); row += 3) {
- int col = 0;
- for (; col < loopBound(n, 3); col += 3) {
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc];
- dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 1) * ldsrc];
- dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 1) * ldsrc];
- dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc];
- dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 2) * ldsrc];
- dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 2) * ldsrc];
- }
- for (; col < n; col += 1) {
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc];
- dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc];
- }
- }
- for (; row < m; row += 1) {
- int col = 0;
- for (; col < loopBound(n, 3); col += 3) {
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc];
- }
- for (; col < n; col += 1) {
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
- }
- }
- }
-
- protected void dgecpyTT(int m, int n, double[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, double[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
- int row = 0;
- for (; row < loopBound(m, 4); row += 4) {
- System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n);
- System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 1) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 1) * lddst, n);
- System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 2) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 2) * lddst, n);
- System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 3) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 3) * lddst, n);
- }
- for (; row < m; row += 1) {
- System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n);
- }
- }
-
- protected void dgebpTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- final int Tcol = 3, Trow = 3;
-
- int col = cols;
- for (; col < loopAlign(cols, cole, Tcol); col += 1) {
- int row = rows;
- for (; row < loopAlign(rows, rowe, Trow); row += 1) {
- double sum00 = 0.0;
- for (int i = is; i < ie; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- }
- for (; row < loopBound(rowe, Trow); row += Trow) {
- double sum00 = 0.0;
- double sum10 = 0.0;
- double sum20 = 0.0;
- for (int i = is; i < ie; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double a1 = a[offseta + i + (row + 1) * lda];
- double a2 = a[offseta + i + (row + 2) * lda];
- double b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- sum10 = a1 * b0 + sum10;
- sum20 = a2 * b0 + sum20;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
- }
- for (; row < rowe; row += 1) {
- double sum00 = 0.0;
- for (int i = is; i < ie; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- }
- }
- for (; col < loopBound(cole, Tcol); col += Tcol) {
- int row = rows;
- for (; row < loopAlign(rows, rowe, Trow); row += 1) {
- double sum00 = 0.0;
- double sum01 = 0.0;
- double sum02 = 0.0;
- double sum03 = 0.0;
- for (int i = is; i < ie; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double b0 = b[offsetb + i + (col + 0) * ldb];
- double b1 = b[offsetb + i + (col + 1) * ldb];
- double b2 = b[offsetb + i + (col + 2) * ldb];
- sum00 = a0 * b0 + sum00;
- sum01 = a0 * b1 + sum01;
- sum02 = a0 * b2 + sum02;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
- }
- for (; row < loopBound(rowe, Trow); row += Trow) {
- dgepdotTN(m, row, row + Trow, n, col, col + Tcol, k, is, ie, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- for (; row < rowe; row += 1) {
- double sum00 = 0.0;
- double sum01 = 0.0;
- double sum02 = 0.0;
- for (int i = is; i < ie; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double b0 = b[offsetb + i + (col + 0) * ldb];
- double b1 = b[offsetb + i + (col + 1) * ldb];
- double b2 = b[offsetb + i + (col + 2) * ldb];
- sum00 = a0 * b0 + sum00;
- sum01 = a0 * b1 + sum01;
- sum02 = a0 * b2 + sum02;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
- }
- }
- for (; col < cole; col += 1) {
- int row = rows;
- for (; row < loopAlign(rows, rowe, Trow); row += 1) {
- double sum00 = 0.0;
- for (int i = is; i < ie; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- }
- for (; row < loopBound(rowe, Trow); row += Trow) {
- double sum00 = 0.0;
- double sum10 = 0.0;
- double sum20 = 0.0;
- for (int i = is; i < ie; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double a1 = a[offseta + i + (row + 1) * lda];
- double a2 = a[offseta + i + (row + 2) * lda];
- double b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- sum10 = a1 * b0 + sum10;
- sum20 = a2 * b0 + sum20;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
- }
- for (; row < rowe; row += 1) {
- double sum00 = 0.0;
- for (int i = is; i < ie; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- }
- }
- }
-
- protected void dgepdotTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- final int Ti = 2;
-
- assert rowe - rows == 3;
- assert cole - cols == 3;
-
- int row = rows;
- int col = cols;
- int i = is;
- double sum00 = 0.0;
- double sum01 = 0.0;
- double sum02 = 0.0;
- double sum10 = 0.0;
- double sum11 = 0.0;
- double sum12 = 0.0;
- double sum20 = 0.0;
- double sum21 = 0.0;
- double sum22 = 0.0;
- for (; i < loopAlign(is, ie, Ti); i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double a1 = a[offseta + i + (row + 1) * lda];
- double a2 = a[offseta + i + (row + 2) * lda];
- double b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- sum10 = a1 * b0 + sum10;
- sum20 = a2 * b0 + sum20;
- double b1 = b[offsetb + i + (col + 1) * ldb];
- sum01 = a0 * b1 + sum01;
- sum11 = a1 * b1 + sum11;
- sum21 = a2 * b1 + sum21;
- double b2 = b[offsetb + i + (col + 2) * ldb];
- sum02 = a0 * b2 + sum02;
- sum12 = a1 * b2 + sum12;
- sum22 = a2 * b2 + sum22;
- }
- for (; i < loopBound(ie, Ti); i += Ti) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double a01 = a[offseta + (i + 0) + (row + 1) * lda];
- double a02 = a[offseta + (i + 0) + (row + 2) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a01 * b00 + sum10;
- sum20 = a02 * b00 + sum20;
- double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- sum01 = a00 * b01 + sum01;
- sum11 = a01 * b01 + sum11;
- sum21 = a02 * b01 + sum21;
- double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum02 = a00 * b02 + sum02;
- sum12 = a01 * b02 + sum12;
- sum22 = a02 * b02 + sum22;
- double a10 = a[offseta + (i + 1) + (row + 0) * lda];
- double a11 = a[offseta + (i + 1) + (row + 1) * lda];
- double a12 = a[offseta + (i + 1) + (row + 2) * lda];
- double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- sum00 = a10 * b10 + sum00;
- sum10 = a11 * b10 + sum10;
- sum20 = a12 * b10 + sum20;
- double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- sum01 = a10 * b11 + sum01;
- sum11 = a11 * b11 + sum11;
- sum21 = a12 * b11 + sum21;
- double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- sum02 = a10 * b12 + sum02;
- sum12 = a11 * b12 + sum12;
- sum22 = a12 * b12 + sum22;
- }
- for (; i < ie; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double a1 = a[offseta + i + (row + 1) * lda];
- double a2 = a[offseta + i + (row + 2) * lda];
- double b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- sum10 = a1 * b0 + sum10;
- sum20 = a2 * b0 + sum20;
- double b1 = b[offsetb + i + (col + 1) * ldb];
- sum01 = a0 * b1 + sum01;
- sum11 = a1 * b1 + sum11;
- sum21 = a2 * b1 + sum21;
- double b2 = b[offsetb + i + (col + 2) * ldb];
- sum02 = a0 * b2 + sum02;
- sum12 = a1 * b2 + sum12;
- sum22 = a2 * b2 + sum22;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + c[offsetc + (row + 2) + (col + 2) * ldc];
- }
-
- protected void dgemmNN(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- final int Trow = 3, Tcol = 3, Ti = 2;
-
- int col = 0;
- for (; col < loopBound(n, Tcol); col += Tcol) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- double sum00 = 0.0;
- double sum01 = 0.0;
- double sum02 = 0.0;
- double sum10 = 0.0;
- double sum11 = 0.0;
- double sum12 = 0.0;
- double sum20 = 0.0;
- double sum21 = 0.0;
- double sum22 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double a10 = a[offseta + (row + 1) + (i + 0) * lda];
- double a20 = a[offseta + (row + 2) + (i + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- double a01 = a[offseta + (row + 0) + (i + 1) * lda];
- double a11 = a[offseta + (row + 1) + (i + 1) * lda];
- double a21 = a[offseta + (row + 2) + (i + 1) * lda];
- double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- sum10 = a11 * b10 + sum10;
- sum11 = a11 * b11 + sum11;
- sum12 = a11 * b12 + sum12;
- sum20 = a21 * b10 + sum20;
- sum21 = a21 * b11 + sum21;
- sum22 = a21 * b12 + sum22;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double a10 = a[offseta + (row + 1) + (i + 0) * lda];
- double a20 = a[offseta + (row + 2) + (i + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- double sum00 = 0.0;
- double sum01 = 0.0;
- double sum02 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- double a01 = a[offseta + (row + 0) + (i + 1) * lda];
- double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- double sum00 = 0.0;
- double sum10 = 0.0;
- double sum20 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double a10 = a[offseta + (row + 1) + (i + 0) * lda];
- double a20 = a[offseta + (row + 2) + (i + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- double a01 = a[offseta + (row + 0) + (i + 1) * lda];
- double a11 = a[offseta + (row + 1) + (i + 1) * lda];
- double a21 = a[offseta + (row + 2) + (i + 1) * lda];
- double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- sum00 = a01 * b10 + sum00;
- sum10 = a11 * b10 + sum10;
- sum20 = a21 * b10 + sum20;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double a10 = a[offseta + (row + 1) + (i + 0) * lda];
- double a20 = a[offseta + (row + 2) + (i + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- double sum00 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- double a01 = a[offseta + (row + 0) + (i + 1) * lda];
- double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- sum00 = a01 * b10 + sum00;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- }
- }
- }
- }
-
- protected void dgemmNT(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- final int Trow = 3, Tcol = 3, Ti = 2;
- int col = 0;
- for (; col < loopBound(n, Tcol); col += Tcol) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- double sum00 = 0.0;
- double sum01 = 0.0;
- double sum02 = 0.0;
- double sum10 = 0.0;
- double sum11 = 0.0;
- double sum12 = 0.0;
- double sum20 = 0.0;
- double sum21 = 0.0;
- double sum22 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double a10 = a[offseta + (row + 1) + (i + 0) * lda];
- double a20 = a[offseta + (row + 2) + (i + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- double a01 = a[offseta + (row + 0) + (i + 1) * lda];
- double a11 = a[offseta + (row + 1) + (i + 1) * lda];
- double a21 = a[offseta + (row + 2) + (i + 1) * lda];
- double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- double b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
- double b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- sum10 = a11 * b10 + sum10;
- sum11 = a11 * b11 + sum11;
- sum12 = a11 * b12 + sum12;
- sum20 = a21 * b10 + sum20;
- sum21 = a21 * b11 + sum21;
- sum22 = a21 * b12 + sum22;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double a10 = a[offseta + (row + 1) + (i + 0) * lda];
- double a20 = a[offseta + (row + 2) + (i + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- double sum00 = 0.0;
- double sum01 = 0.0;
- double sum02 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- double a01 = a[offseta + (row + 0) + (i + 1) * lda];
- double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- double b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
- double b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- double sum00 = 0.0;
- double sum10 = 0.0;
- double sum20 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double a10 = a[offseta + (row + 1) + (i + 0) * lda];
- double a20 = a[offseta + (row + 2) + (i + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- double a01 = a[offseta + (row + 0) + (i + 1) * lda];
- double a11 = a[offseta + (row + 1) + (i + 1) * lda];
- double a21 = a[offseta + (row + 2) + (i + 1) * lda];
- double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum10 = a11 * b10 + sum10;
- sum20 = a21 * b10 + sum20;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double a10 = a[offseta + (row + 1) + (i + 0) * lda];
- double a20 = a[offseta + (row + 2) + (i + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- double sum00 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- double a01 = a[offseta + (row + 0) + (i + 1) * lda];
- double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (row + 0) + (i + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- }
- }
- }
- }
-
- protected void dgemmTN(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- final int Trow = 3, Tcol = 3, Ti = 2;
-
- int col = 0;
- for (; col < loopBound(n, Tcol); col += Tcol) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- double sum00 = 0.0;
- double sum01 = 0.0;
- double sum02 = 0.0;
- double sum10 = 0.0;
- double sum11 = 0.0;
- double sum12 = 0.0;
- double sum20 = 0.0;
- double sum21 = 0.0;
- double sum22 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double a10 = a[offseta + (i + 0) + (row + 1) * lda];
- double a20 = a[offseta + (i + 0) + (row + 2) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- double a01 = a[offseta + (i + 1) + (row + 0) * lda];
- double a11 = a[offseta + (i + 1) + (row + 1) * lda];
- double a21 = a[offseta + (i + 1) + (row + 2) * lda];
- double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- sum10 = a11 * b10 + sum10;
- sum11 = a11 * b11 + sum11;
- sum12 = a11 * b12 + sum12;
- sum20 = a21 * b10 + sum20;
- sum21 = a21 * b11 + sum21;
- sum22 = a21 * b12 + sum22;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double a10 = a[offseta + (i + 0) + (row + 1) * lda];
- double a20 = a[offseta + (i + 0) + (row + 2) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- double sum00 = 0.0;
- double sum01 = 0.0;
- double sum02 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- double a01 = a[offseta + (i + 1) + (row + 0) * lda];
- double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- double sum00 = 0.0;
- double sum10 = 0.0;
- double sum20 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double a10 = a[offseta + (i + 0) + (row + 1) * lda];
- double a20 = a[offseta + (i + 0) + (row + 2) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- double a01 = a[offseta + (i + 1) + (row + 0) * lda];
- double a11 = a[offseta + (i + 1) + (row + 1) * lda];
- double a21 = a[offseta + (i + 1) + (row + 2) * lda];
- double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- sum00 = a01 * b10 + sum00;
- sum10 = a11 * b10 + sum10;
- sum20 = a21 * b10 + sum20;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double a10 = a[offseta + (i + 0) + (row + 1) * lda];
- double a20 = a[offseta + (i + 0) + (row + 2) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- double sum00 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- double a01 = a[offseta + (i + 1) + (row + 0) * lda];
- double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- sum00 = a01 * b10 + sum00;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- }
- }
- }
- }
-
- protected void dgemmTT(int m, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- final int Trow = 3, Tcol = 3, Ti = 2;
-
- int col = 0;
- for (; col < loopBound(n, Tcol); col += Tcol) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- double sum00 = 0.0;
- double sum01 = 0.0;
- double sum02 = 0.0;
- double sum10 = 0.0;
- double sum11 = 0.0;
- double sum12 = 0.0;
- double sum20 = 0.0;
- double sum21 = 0.0;
- double sum22 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double a10 = a[offseta + (i + 0) + (row + 1) * lda];
- double a20 = a[offseta + (i + 0) + (row + 2) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- double a01 = a[offseta + (i + 1) + (row + 0) * lda];
- double a11 = a[offseta + (i + 1) + (row + 1) * lda];
- double a21 = a[offseta + (i + 1) + (row + 2) * lda];
- double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- double b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
- double b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- sum10 = a11 * b10 + sum10;
- sum11 = a11 * b11 + sum11;
- sum12 = a11 * b12 + sum12;
- sum20 = a21 * b10 + sum20;
- sum21 = a21 * b11 + sum21;
- sum22 = a21 * b12 + sum22;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double a10 = a[offseta + (i + 0) + (row + 1) * lda];
- double a20 = a[offseta + (i + 0) + (row + 2) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- double sum00 = 0.0;
- double sum01 = 0.0;
- double sum02 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- double a01 = a[offseta + (i + 1) + (row + 0) * lda];
- double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- double b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
- double b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- double b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- double b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- double sum00 = 0.0;
- double sum10 = 0.0;
- double sum20 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double a10 = a[offseta + (i + 0) + (row + 1) * lda];
- double a20 = a[offseta + (i + 0) + (row + 2) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- double a01 = a[offseta + (i + 1) + (row + 0) * lda];
- double a11 = a[offseta + (i + 1) + (row + 1) * lda];
- double a21 = a[offseta + (i + 1) + (row + 2) * lda];
- double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum10 = a11 * b10 + sum10;
- sum20 = a21 * b10 + sum20;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double a10 = a[offseta + (i + 0) + (row + 1) * lda];
- double a20 = a[offseta + (i + 0) + (row + 2) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- double sum00 = 0.0;
- for (; i < loopBound(k, Ti); i += Ti) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- double a01 = a[offseta + (i + 1) + (row + 0) * lda];
- double b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- }
- for (; i < k; i += 1) {
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- }
- }
- }
- }
-
- protected void sgemmK(String transa, String transb, int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- if (alpha == 0.0f) {
- sgemmBeta(0, m, 0, n, beta, c, offsetc, ldc);
- } else if (m * n * k < 100 * 100 * 100) {
- // The matrices are small and it's faster to do the non-copying version
- if (lsame("N", transa) && lsame("N", transb)) {
- sgemmNN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else if (lsame("N", transa)) {
- sgemmNT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else if (lsame("N", transb)) {
- sgemmTN(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else {
- sgemmTT(m, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- } else {
- final int Krow = (int)(Math.ceil((double)(Math.min(60, m)) / 3) * 3),
- Kcol = (int)(Math.ceil((double)(Math.min(1000, n)) / 3) * 3),
- Ki = (int)(Math.ceil((double)(Math.min(500, k)) / 4) * 4);
-
- assert Krow > 0;
- assert Kcol > 0;
- assert Ki > 0;
-
- float[] packeda = new float[Krow * Ki];
- float[] packedb = new float[Kcol * Ki];
- float[] packedc = new float[Kcol * Krow];
-
- // c = beta * c
- sgemmBeta(0, m, 0, n, beta, c, offsetc, ldc);
- // c += alpha * a * b
- for (int col = 0; col < n; col += Kcol) {
- int cols = col, cole = Math.min(col + Kcol, n);
- for (int i = 0; i < k; i += Ki) {
- int is = i, ie = Math.min(i + Ki, k);
- // pack b
- if (lsame("N", transb)) {
- sgecpyNN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0);
- } else {
- sgecpyTN(ie - is, cole - cols, b, offsetb, ldb, is, cols, packedb, 0, Ki, 0, 0);
- }
- // GEPP
- for (int row = 0; row < m; row += Krow) {
- int rows = row, rowe = Math.min(row + Krow, m);
- // pack A
- if (lsame("N", transa)) {
- sgecpyNT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0);
- } else {
- sgecpyTT(rowe - rows, ie - is, a, offseta, lda, rows, is, packeda, 0, Ki, 0, 0);
- }
- // pack C
- sgecpyNN(rowe - rows, cole - cols, c, offsetc, ldc, rows, cols, packedc, 0, Krow, 0, 0);
- // GEBP
- sgebpTN(Krow, 0, rowe - rows, Kcol, 0, cole - cols, Ki, 0, ie - is,
- alpha, packeda, 0, Ki, packedb, 0, Ki, beta, packedc, 0, Krow);
- // unpack C
- sgecpyNN(rowe - rows, cole - cols, packedc, 0, Krow, 0, 0, c, offsetc, ldc, rows, cols);
- }
- }
- }
- }
- }
-
- protected void sgemmBeta(int rows, int rowe, int cols, int cole, float beta, float[] c, int offsetc, int ldc) {
- if (beta != 1.0f) {
- int col = cols;
- for (; col < loopAlign(cols, cole, 4); col += 1) {
- int row = rows;
- for (; row < rowe; row += 1) {
- if (beta != 0.0f) {
- c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
- } else {
- c[offsetc + row + (col + 0) * ldc] = 0.0f;
- }
- }
- }
- for (; col < loopBound(cole, 4); col += 4) {
- int row = rows;
- for (; row < rowe; row += 1) {
- if (beta != 0.0f) {
- c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
- c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc];
- c[offsetc + row + (col + 2) * ldc] = beta * c[offsetc + row + (col + 2) * ldc];
- c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc];
- } else {
- c[offsetc + row + (col + 0) * ldc] = 0.0f;
- c[offsetc + row + (col + 1) * ldc] = 0.0f;
- c[offsetc + row + (col + 2) * ldc] = 0.0f;
- c[offsetc + row + (col + 3) * ldc] = 0.0f;
- }
- }
- }
- for (; col < cole; col += 1) {
- int row = rows;
- for (; row < rowe; row += 1) {
- if (beta != 0.0f) {
- c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
- } else {
- c[offsetc + row + (col + 0) * ldc] = 0.0f;
- }
- }
- }
- }
- }
-
- protected void sgecpyNN(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
- int col = 0;
- for (; col < loopBound(n, 4); col += 4) {
- System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m);
- System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 1) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 1) * lddst, m);
- System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 2) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 2) * lddst, m);
- System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 3) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 3) * lddst, m);
- }
- for (; col < n; col += 1) {
- System.arraycopy(src, offsetsrc + rowssrc + (colssrc + col + 0) * ldsrc, dst, offsetdst + rowsdst + (colsdst + col + 0) * lddst, m);
- }
- }
-
- protected void sgecpyNT(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
- int col = 0;
- for (; col < loopBound(n, 3); col += 3) {
- int row = 0;
- for (; row < loopBound(m, 3); row += 3) {
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc];
- dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 1) * ldsrc];
- dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 1) * ldsrc];
- dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc];
- dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 2) * ldsrc];
- dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 2) * ldsrc];
- }
- for (; row < m; row += 1) {
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 1) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 1) * ldsrc];
- dst[offsetdst + (colsdst + col + 2) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 2) * ldsrc];
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, 3); row += 3) {
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 1) * lddst] = src[offsetsrc + (rowssrc + row + 1) + (colssrc + col + 0) * ldsrc];
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 2) * lddst] = src[offsetsrc + (rowssrc + row + 2) + (colssrc + col + 0) * ldsrc];
- }
- for (; row < m; row += 1) {
- dst[offsetdst + (colsdst + col + 0) + (rowsdst + row + 0) * lddst] = src[offsetsrc + (rowssrc + row + 0) + (colssrc + col + 0) * ldsrc];
- }
- }
- }
-
- protected void sgecpyTN(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
- int row = 0;
- for (; row < loopBound(m, 3); row += 3) {
- int col = 0;
- for (; col < loopBound(n, 3); col += 3) {
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc];
- dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 1) * ldsrc];
- dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 1) * ldsrc];
- dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc];
- dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 2) * ldsrc];
- dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 2) * ldsrc];
- }
- for (; col < n; col += 1) {
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 1) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 1) * ldsrc];
- dst[offsetdst + (rowsdst + row + 2) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 2) * ldsrc];
- }
- }
- for (; row < m; row += 1) {
- int col = 0;
- for (; col < loopBound(n, 3); col += 3) {
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 1) * lddst] = src[offsetsrc + (colssrc + col + 1) + (rowssrc + row + 0) * ldsrc];
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 2) * lddst] = src[offsetsrc + (colssrc + col + 2) + (rowssrc + row + 0) * ldsrc];
- }
- for (; col < n; col += 1) {
- dst[offsetdst + (rowsdst + row + 0) + (colsdst + col + 0) * lddst] = src[offsetsrc + (colssrc + col + 0) + (rowssrc + row + 0) * ldsrc];
- }
- }
- }
-
- protected void sgecpyTT(int m, int n, float[] src, int offsetsrc, int ldsrc, int rowssrc, int colssrc, float[] dst, int offsetdst, int lddst, int rowsdst, int colsdst) {
- int row = 0;
- for (; row < loopBound(m, 4); row += 4) {
- System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n);
- System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 1) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 1) * lddst, n);
- System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 2) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 2) * lddst, n);
- System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 3) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 3) * lddst, n);
- }
- for (; row < m; row += 1) {
- System.arraycopy(src, offsetsrc + colssrc + (rowssrc + row + 0) * ldsrc, dst, offsetdst + colsdst + (rowsdst + row + 0) * lddst, n);
- }
- }
-
- protected void sgebpTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- final int Tcol = 3, Trow = 3, Ti = 2;
-
- int col = cols;
- for (; col < loopAlign(cols, cole, Tcol); col += 1) {
- int row = rows;
- for (; row < loopAlign(rows, rowe, Trow); row += 1) {
- float sum00 = 0.0f;
- for (int i = is; i < ie; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- }
- for (; row < loopBound(rowe, Trow); row += Trow) {
- float sum00 = 0.0f;
- float sum10 = 0.0f;
- float sum20 = 0.0f;
- for (int i = is; i < ie; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float a1 = a[offseta + i + (row + 1) * lda];
- float a2 = a[offseta + i + (row + 2) * lda];
- float b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- sum10 = a1 * b0 + sum10;
- sum20 = a2 * b0 + sum20;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
- }
- for (; row < rowe; row += 1) {
- float sum00 = 0.0f;
- for (int i = is; i < ie; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- }
- }
- for (; col < loopBound(cole, Tcol); col += Tcol) {
- int row = rows;
- for (; row < loopAlign(rows, rowe, Trow); row += 1) {
- float sum00 = 0.0f;
- float sum01 = 0.0f;
- float sum02 = 0.0f;
- float sum03 = 0.0f;
- for (int i = is; i < ie; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float b0 = b[offsetb + i + (col + 0) * ldb];
- float b1 = b[offsetb + i + (col + 1) * ldb];
- float b2 = b[offsetb + i + (col + 2) * ldb];
- sum00 = a0 * b0 + sum00;
- sum01 = a0 * b1 + sum01;
- sum02 = a0 * b2 + sum02;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
- }
- for (; row < loopBound(rowe, Trow); row += Trow) {
- sgepdotTN(m, row, row + Trow, n, col, col + Tcol, k, is, ie, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- for (; row < rowe; row += 1) {
- float sum00 = 0.0f;
- float sum01 = 0.0f;
- float sum02 = 0.0f;
- for (int i = is; i < ie; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float b0 = b[offsetb + i + (col + 0) * ldb];
- float b1 = b[offsetb + i + (col + 1) * ldb];
- float b2 = b[offsetb + i + (col + 2) * ldb];
- sum00 = a0 * b0 + sum00;
- sum01 = a0 * b1 + sum01;
- sum02 = a0 * b2 + sum02;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
- }
- }
- for (; col < cole; col += 1) {
- int row = rows;
- for (; row < loopAlign(rows, rowe, Trow); row += 1) {
- float sum00 = 0.0f;
- for (int i = is; i < ie; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- }
- for (; row < loopBound(rowe, Trow); row += Trow) {
- float sum00 = 0.0f;
- float sum10 = 0.0f;
- float sum20 = 0.0f;
- for (int i = is; i < ie; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float a1 = a[offseta + i + (row + 1) * lda];
- float a2 = a[offseta + i + (row + 2) * lda];
- float b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- sum10 = a1 * b0 + sum10;
- sum20 = a2 * b0 + sum20;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
- }
- for (; row < rowe; row += 1) {
- float sum00 = 0.0f;
- for (int i = is; i < ie; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- }
- }
- }
-
- protected void sgepdotTN(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- final int Ti = 2;
-
- assert rowe - rows == 3;
- assert cole - cols == 3;
-
- int row = rows;
- int col = cols;
- int i = is;
- float sum00 = 0.0f;
- float sum01 = 0.0f;
- float sum02 = 0.0f;
- float sum10 = 0.0f;
- float sum11 = 0.0f;
- float sum12 = 0.0f;
- float sum20 = 0.0f;
- float sum21 = 0.0f;
- float sum22 = 0.0f;
- for (; i < loopAlign(is, ie, Ti); i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float a1 = a[offseta + i + (row + 1) * lda];
- float a2 = a[offseta + i + (row + 2) * lda];
- float b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- sum10 = a1 * b0 + sum10;
- sum20 = a2 * b0 + sum20;
- float b1 = b[offsetb + i + (col + 1) * ldb];
- sum01 = a0 * b1 + sum01;
- sum11 = a1 * b1 + sum11;
- sum21 = a2 * b1 + sum21;
- float b2 = b[offsetb + i + (col + 2) * ldb];
- sum02 = a0 * b2 + sum02;
- sum12 = a1 * b2 + sum12;
- sum22 = a2 * b2 + sum22;
- }
- for (; i < loopBound(ie, Ti); i += Ti) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float a01 = a[offseta + (i + 0) + (row + 1) * lda];
- float a02 = a[offseta + (i + 0) + (row + 2) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a01 * b00 + sum10;
- sum20 = a02 * b00 + sum20;
- float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- sum01 = a00 * b01 + sum01;
- sum11 = a01 * b01 + sum11;
- sum21 = a02 * b01 + sum21;
- float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum02 = a00 * b02 + sum02;
- sum12 = a01 * b02 + sum12;
- sum22 = a02 * b02 + sum22;
- float a10 = a[offseta + (i + 1) + (row + 0) * lda];
- float a11 = a[offseta + (i + 1) + (row + 1) * lda];
- float a12 = a[offseta + (i + 1) + (row + 2) * lda];
- float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- sum00 = a10 * b10 + sum00;
- sum10 = a11 * b10 + sum10;
- sum20 = a12 * b10 + sum20;
- float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- sum01 = a10 * b11 + sum01;
- sum11 = a11 * b11 + sum11;
- sum21 = a12 * b11 + sum21;
- float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- sum02 = a10 * b12 + sum02;
- sum12 = a11 * b12 + sum12;
- sum22 = a12 * b12 + sum22;
- }
- for (; i < ie; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float a1 = a[offseta + i + (row + 1) * lda];
- float a2 = a[offseta + i + (row + 2) * lda];
- float b0 = b[offsetb + i + (col + 0) * ldb];
- sum00 = a0 * b0 + sum00;
- sum10 = a1 * b0 + sum10;
- sum20 = a2 * b0 + sum20;
- float b1 = b[offsetb + i + (col + 1) * ldb];
- sum01 = a0 * b1 + sum01;
- sum11 = a1 * b1 + sum11;
- sum21 = a2 * b1 + sum21;
- float b2 = b[offsetb + i + (col + 2) * ldb];
- sum02 = a0 * b2 + sum02;
- sum12 = a1 * b2 + sum12;
- sum22 = a2 * b2 + sum22;
- }
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + c[offsetc + (row + 2) + (col + 2) * ldc];
- }
-
- protected void sgemmNN(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- final int Trow = 3, Tcol = 3, Ti = 2;
-
- int col = 0;
- for (; col < loopBound(n, Tcol); col += Tcol) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- float sum00 = 0.0f;
- float sum01 = 0.0f;
- float sum02 = 0.0f;
- float sum10 = 0.0f;
- float sum11 = 0.0f;
- float sum12 = 0.0f;
- float sum20 = 0.0f;
- float sum21 = 0.0f;
- float sum22 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float a10 = a[offseta + (row + 1) + (i + 0) * lda];
- float a20 = a[offseta + (row + 2) + (i + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- float a01 = a[offseta + (row + 0) + (i + 1) * lda];
- float a11 = a[offseta + (row + 1) + (i + 1) * lda];
- float a21 = a[offseta + (row + 2) + (i + 1) * lda];
- float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- sum10 = a11 * b10 + sum10;
- sum11 = a11 * b11 + sum11;
- sum12 = a11 * b12 + sum12;
- sum20 = a21 * b10 + sum20;
- sum21 = a21 * b11 + sum21;
- sum22 = a21 * b12 + sum22;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float a10 = a[offseta + (row + 1) + (i + 0) * lda];
- float a20 = a[offseta + (row + 2) + (i + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- float sum00 = 0.0f;
- float sum01 = 0.0f;
- float sum02 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- float a01 = a[offseta + (row + 0) + (i + 1) * lda];
- float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- float sum00 = 0.0f;
- float sum10 = 0.0f;
- float sum20 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float a10 = a[offseta + (row + 1) + (i + 0) * lda];
- float a20 = a[offseta + (row + 2) + (i + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- float a01 = a[offseta + (row + 0) + (i + 1) * lda];
- float a11 = a[offseta + (row + 1) + (i + 1) * lda];
- float a21 = a[offseta + (row + 2) + (i + 1) * lda];
- float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- sum00 = a01 * b10 + sum00;
- sum10 = a11 * b10 + sum10;
- sum20 = a21 * b10 + sum20;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float a10 = a[offseta + (row + 1) + (i + 0) * lda];
- float a20 = a[offseta + (row + 2) + (i + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- float sum00 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- float a01 = a[offseta + (row + 0) + (i + 1) * lda];
- float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- sum00 = a01 * b10 + sum00;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- }
- }
- }
- }
-
- protected void sgemmNT(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- final int Trow = 3, Tcol = 3, Ti = 2;
- int col = 0;
- for (; col < loopBound(n, Tcol); col += Tcol) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- float sum00 = 0.0f;
- float sum01 = 0.0f;
- float sum02 = 0.0f;
- float sum10 = 0.0f;
- float sum11 = 0.0f;
- float sum12 = 0.0f;
- float sum20 = 0.0f;
- float sum21 = 0.0f;
- float sum22 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float a10 = a[offseta + (row + 1) + (i + 0) * lda];
- float a20 = a[offseta + (row + 2) + (i + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- float a01 = a[offseta + (row + 0) + (i + 1) * lda];
- float a11 = a[offseta + (row + 1) + (i + 1) * lda];
- float a21 = a[offseta + (row + 2) + (i + 1) * lda];
- float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- float b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
- float b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- sum10 = a11 * b10 + sum10;
- sum11 = a11 * b11 + sum11;
- sum12 = a11 * b12 + sum12;
- sum20 = a21 * b10 + sum20;
- sum21 = a21 * b11 + sum21;
- sum22 = a21 * b12 + sum22;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float a10 = a[offseta + (row + 1) + (i + 0) * lda];
- float a20 = a[offseta + (row + 2) + (i + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- float sum00 = 0.0f;
- float sum01 = 0.0f;
- float sum02 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- float a01 = a[offseta + (row + 0) + (i + 1) * lda];
- float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- float b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
- float b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- float sum00 = 0.0f;
- float sum10 = 0.0f;
- float sum20 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float a10 = a[offseta + (row + 1) + (i + 0) * lda];
- float a20 = a[offseta + (row + 2) + (i + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- float a01 = a[offseta + (row + 0) + (i + 1) * lda];
- float a11 = a[offseta + (row + 1) + (i + 1) * lda];
- float a21 = a[offseta + (row + 2) + (i + 1) * lda];
- float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum10 = a11 * b10 + sum10;
- sum20 = a21 * b10 + sum20;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float a10 = a[offseta + (row + 1) + (i + 0) * lda];
- float a20 = a[offseta + (row + 2) + (i + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- float sum00 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- float a01 = a[offseta + (row + 0) + (i + 1) * lda];
- float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (row + 0) + (i + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- }
- }
- }
- }
-
- protected void sgemmTN(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- final int Trow = 3, Tcol = 3, Ti = 2;
-
- int col = 0;
- for (; col < loopBound(n, Tcol); col += Tcol) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- float sum00 = 0.0f;
- float sum01 = 0.0f;
- float sum02 = 0.0f;
- float sum10 = 0.0f;
- float sum11 = 0.0f;
- float sum12 = 0.0f;
- float sum20 = 0.0f;
- float sum21 = 0.0f;
- float sum22 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float a10 = a[offseta + (i + 0) + (row + 1) * lda];
- float a20 = a[offseta + (i + 0) + (row + 2) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- float a01 = a[offseta + (i + 1) + (row + 0) * lda];
- float a11 = a[offseta + (i + 1) + (row + 1) * lda];
- float a21 = a[offseta + (i + 1) + (row + 2) * lda];
- float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- sum10 = a11 * b10 + sum10;
- sum11 = a11 * b11 + sum11;
- sum12 = a11 * b12 + sum12;
- sum20 = a21 * b10 + sum20;
- sum21 = a21 * b11 + sum21;
- sum22 = a21 * b12 + sum22;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float a10 = a[offseta + (i + 0) + (row + 1) * lda];
- float a20 = a[offseta + (i + 0) + (row + 2) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- float sum00 = 0.0f;
- float sum01 = 0.0f;
- float sum02 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- float a01 = a[offseta + (i + 1) + (row + 0) * lda];
- float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- float sum00 = 0.0f;
- float sum10 = 0.0f;
- float sum20 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float a10 = a[offseta + (i + 0) + (row + 1) * lda];
- float a20 = a[offseta + (i + 0) + (row + 2) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- float a01 = a[offseta + (i + 1) + (row + 0) * lda];
- float a11 = a[offseta + (i + 1) + (row + 1) * lda];
- float a21 = a[offseta + (i + 1) + (row + 2) * lda];
- float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- sum00 = a01 * b10 + sum00;
- sum10 = a11 * b10 + sum10;
- sum20 = a21 * b10 + sum20;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float a10 = a[offseta + (i + 0) + (row + 1) * lda];
- float a20 = a[offseta + (i + 0) + (row + 2) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- float sum00 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- float a01 = a[offseta + (i + 1) + (row + 0) * lda];
- float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- sum00 = a01 * b10 + sum00;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- }
- }
- }
- }
-
- protected void sgemmTT(int m, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- final int Trow = 3, Tcol = 3, Ti = 2;
-
- int col = 0;
- for (; col < loopBound(n, Tcol); col += Tcol) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- float sum00 = 0.0f;
- float sum01 = 0.0f;
- float sum02 = 0.0f;
- float sum10 = 0.0f;
- float sum11 = 0.0f;
- float sum12 = 0.0f;
- float sum20 = 0.0f;
- float sum21 = 0.0f;
- float sum22 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float a10 = a[offseta + (i + 0) + (row + 1) * lda];
- float a20 = a[offseta + (i + 0) + (row + 2) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- float a01 = a[offseta + (i + 1) + (row + 0) * lda];
- float a11 = a[offseta + (i + 1) + (row + 1) * lda];
- float a21 = a[offseta + (i + 1) + (row + 2) * lda];
- float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- float b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
- float b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- sum10 = a11 * b10 + sum10;
- sum11 = a11 * b11 + sum11;
- sum12 = a11 * b12 + sum12;
- sum20 = a21 * b10 + sum20;
- sum21 = a21 * b11 + sum21;
- sum22 = a21 * b12 + sum22;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float a10 = a[offseta + (i + 0) + (row + 1) * lda];
- float a20 = a[offseta + (i + 0) + (row + 2) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- sum10 = a10 * b00 + sum10;
- sum11 = a10 * b01 + sum11;
- sum12 = a10 * b02 + sum12;
- sum20 = a20 * b00 + sum20;
- sum21 = a20 * b01 + sum21;
- sum22 = a20 * b02 + sum22;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- float sum00 = 0.0f;
- float sum01 = 0.0f;
- float sum02 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- float a01 = a[offseta + (i + 1) + (row + 0) * lda];
- float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- float b11 = b[offsetb + (col + 1) + (i + 1) * ldb];
- float b12 = b[offsetb + (col + 2) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum01 = a01 * b11 + sum01;
- sum02 = a01 * b12 + sum02;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- float b01 = b[offsetb + (col + 1) + (i + 0) * ldb];
- float b02 = b[offsetb + (col + 2) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum01 = a00 * b01 + sum01;
- sum02 = a00 * b02 + sum02;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, Trow); row += Trow) {
- int i = 0;
- float sum00 = 0.0f;
- float sum10 = 0.0f;
- float sum20 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float a10 = a[offseta + (i + 0) + (row + 1) * lda];
- float a20 = a[offseta + (i + 0) + (row + 2) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- float a01 = a[offseta + (i + 1) + (row + 0) * lda];
- float a11 = a[offseta + (i + 1) + (row + 1) * lda];
- float a21 = a[offseta + (i + 1) + (row + 2) * lda];
- float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- sum10 = a11 * b10 + sum10;
- sum20 = a21 * b10 + sum20;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float a10 = a[offseta + (i + 0) + (row + 1) * lda];
- float a20 = a[offseta + (i + 0) + (row + 2) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- sum10 = a10 * b00 + sum10;
- sum20 = a20 * b00 + sum20;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- }
- }
- for (; row < m; row += 1) {
- int i = 0;
- float sum00 = 0.0f;
- for (; i < loopBound(k, Ti); i += Ti) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- float a01 = a[offseta + (i + 1) + (row + 0) * lda];
- float b10 = b[offsetb + (col + 0) + (i + 1) * ldb];
- sum00 = a01 * b10 + sum00;
- }
- for (; i < k; i += 1) {
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float b00 = b[offsetb + (col + 0) + (i + 0) * ldb];
- sum00 = a00 * b00 + sum00;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- }
- }
- }
- }
-
- protected void dgemvK(String trans, int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- if (alpha == 0.0) {
- int len = lsame("N", trans) ? m : n;
- for (int i = 0, iy = incy < 0 ? (len - 1) * -incy : 0; i < len; i += 1, iy += incy) {
- if (beta != 0.0) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0;
- }
- }
- } else if (lsame("N", trans)) {
- dgemvN(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- } else if (lsame("T", trans) || lsame("C", trans)) {
- dgemvT(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- }
-
- protected void dgemvN(int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- if (beta != 1.0) {
- int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0;
- for (; row < m; row += 1, iy += incy) {
- if (beta != 0.0) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0;
- }
- }
- }
- int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0;
- for (; col < loopBound(n, 4); col += 4, ix += incx * 4) {
- int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0;
- double alphax0 = alpha * x[offsetx + ix + incx * 0];
- double alphax1 = alpha * x[offsetx + ix + incx * 1];
- double alphax2 = alpha * x[offsetx + ix + incx * 2];
- double alphax3 = alpha * x[offsetx + ix + incx * 3];
- for (; row < m; row += 1, iy += incy) {
- y[offsety + iy] += alphax0 * a[offseta + row + (col + 0) * lda]
- + alphax1 * a[offseta + row + (col + 1) * lda]
- + alphax2 * a[offseta + row + (col + 2) * lda]
- + alphax3 * a[offseta + row + (col + 3) * lda];
- }
- }
- for (; col < n; col += 1, ix += incx) {
- int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0;
- double alphax = alpha * x[offsetx + ix];
- for (; row < m; row += 1, iy += incy) {
- y[offsety + iy] += alphax * a[offseta + row + col * lda];
- }
- }
- }
-
- protected void dgemvT(int m, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- int col = 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 4); col += 4, iy += incy * 4) {
- int row = 0, ix = incx < 0 ? (m - 1) * -incx : 0;
- double sum0 = 0.0;
- double sum1 = 0.0;
- double sum2 = 0.0;
- double sum3 = 0.0;
- for (; row < m; row += 1, ix += incx) {
- double xix = x[offsetx + ix];
- sum0 += xix * a[offseta + row + (col + 0) * lda];
- sum1 += xix * a[offseta + row + (col + 1) * lda];
- sum2 += xix * a[offseta + row + (col + 2) * lda];
- sum3 += xix * a[offseta + row + (col + 3) * lda];
- }
- if (beta != 0.0) {
- y[offsety + iy + incy * 0] = alpha * sum0 + beta * y[offsety + iy + incy * 0];
- y[offsety + iy + incy * 1] = alpha * sum1 + beta * y[offsety + iy + incy * 1];
- y[offsety + iy + incy * 2] = alpha * sum2 + beta * y[offsety + iy + incy * 2];
- y[offsety + iy + incy * 3] = alpha * sum3 + beta * y[offsety + iy + incy * 3];
- } else {
- y[offsety + iy + incy * 0] = alpha * sum0;
- y[offsety + iy + incy * 1] = alpha * sum1;
- y[offsety + iy + incy * 2] = alpha * sum2;
- y[offsety + iy + incy * 3] = alpha * sum3;
- }
- }
- for (; col < n; col += 1, iy += incy) {
- int row = 0, ix = incx < 0 ? (m - 1) * -incx : 0;
- double sum = 0.0;
- for (; row < m; row += 1, ix += incx) {
- sum += x[offsetx + ix] * a[offseta + row + col * lda];
- }
- if (beta != 0.0) {
- y[offsety + iy] = alpha * sum + beta * y[offsety + iy];
- } else {
- y[offsety + iy] = alpha * sum;
- }
- }
- }
-
- protected void sgemvK(String trans, int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- if (alpha == 0.0f) {
- int len = lsame("N", trans) ? m : n;
- for (int i = 0, iy = incy < 0 ? (len - 1) * -incy : 0; i < len; i += 1, iy += incy) {
- if (beta != 0.0f) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0f;
- }
- }
- } else if (lsame("N", trans)) {
- sgemvN(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- } else if (lsame("T", trans) || lsame("C", trans)) {
- sgemvT(m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- }
-
- protected void sgemvN(int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- // y = beta * y
- for (int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0; row < m; row += 1, iy += incy) {
- if (beta != 0.0f) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0f;
- }
- }
- // y += alpha * A * x
- int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0;
- for (; col < loopBound(n, 8); col += 8, ix += incx * 8) {
- float alphax0 = alpha * x[offsetx + ix + incx * 0];
- float alphax1 = alpha * x[offsetx + ix + incx * 1];
- float alphax2 = alpha * x[offsetx + ix + incx * 2];
- float alphax3 = alpha * x[offsetx + ix + incx * 3];
- float alphax4 = alpha * x[offsetx + ix + incx * 4];
- float alphax5 = alpha * x[offsetx + ix + incx * 5];
- float alphax6 = alpha * x[offsetx + ix + incx * 6];
- float alphax7 = alpha * x[offsetx + ix + incx * 7];
- for (int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0; row < m; row += 1, iy += incy) {
- y[offsety + iy] += alphax0 * a[offseta + row + (col + 0) * lda]
- + alphax1 * a[offseta + row + (col + 1) * lda]
- + alphax2 * a[offseta + row + (col + 2) * lda]
- + alphax3 * a[offseta + row + (col + 3) * lda]
- + alphax4 * a[offseta + row + (col + 4) * lda]
- + alphax5 * a[offseta + row + (col + 5) * lda]
- + alphax6 * a[offseta + row + (col + 6) * lda]
- + alphax7 * a[offseta + row + (col + 7) * lda];
- }
- }
- for (; col < n; col += 1, ix += incx) {
- float alphax = alpha * x[offsetx + ix];
- for (int row = 0, iy = incy < 0 ? (m - 1) * -incy : 0; row < m; row += 1, iy += incy) {
- y[offsety + iy] += alphax * a[offseta + row + col * lda];
- }
- }
- }
-
- protected void sgemvT(int m, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- int col = 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 8); col += 8, iy += incy * 8) {
- float sum0 = 0.0f;
- float sum1 = 0.0f;
- float sum2 = 0.0f;
- float sum3 = 0.0f;
- float sum4 = 0.0f;
- float sum5 = 0.0f;
- float sum6 = 0.0f;
- float sum7 = 0.0f;
- for (int row = 0, ix = incx < 0 ? (m - 1) * -incx : 0; row < m; row += 1, ix += incx) {
- sum0 += x[offsetx + ix] * a[offseta + row + (col + 0) * lda];
- sum1 += x[offsetx + ix] * a[offseta + row + (col + 1) * lda];
- sum2 += x[offsetx + ix] * a[offseta + row + (col + 2) * lda];
- sum3 += x[offsetx + ix] * a[offseta + row + (col + 3) * lda];
- sum4 += x[offsetx + ix] * a[offseta + row + (col + 4) * lda];
- sum5 += x[offsetx + ix] * a[offseta + row + (col + 5) * lda];
- sum6 += x[offsetx + ix] * a[offseta + row + (col + 6) * lda];
- sum7 += x[offsetx + ix] * a[offseta + row + (col + 7) * lda];
- }
- if (beta != 0.0f) {
- y[offsety + iy + incy * 0] = alpha * sum0 + beta * y[offsety + iy + incy * 0];
- y[offsety + iy + incy * 1] = alpha * sum1 + beta * y[offsety + iy + incy * 1];
- y[offsety + iy + incy * 2] = alpha * sum2 + beta * y[offsety + iy + incy * 2];
- y[offsety + iy + incy * 3] = alpha * sum3 + beta * y[offsety + iy + incy * 3];
- y[offsety + iy + incy * 4] = alpha * sum4 + beta * y[offsety + iy + incy * 4];
- y[offsety + iy + incy * 5] = alpha * sum5 + beta * y[offsety + iy + incy * 5];
- y[offsety + iy + incy * 6] = alpha * sum6 + beta * y[offsety + iy + incy * 6];
- y[offsety + iy + incy * 7] = alpha * sum7 + beta * y[offsety + iy + incy * 7];
- } else {
- y[offsety + iy + incy * 0] = alpha * sum0;
- y[offsety + iy + incy * 1] = alpha * sum1;
- y[offsety + iy + incy * 2] = alpha * sum2;
- y[offsety + iy + incy * 3] = alpha * sum3;
- y[offsety + iy + incy * 4] = alpha * sum4;
- y[offsety + iy + incy * 5] = alpha * sum5;
- y[offsety + iy + incy * 6] = alpha * sum6;
- y[offsety + iy + incy * 7] = alpha * sum7;
- }
- }
- for (; col < n; col += 1, iy += incy) {
- float sum = 0.0f;
- for (int row = 0, ix = incx < 0 ? (m - 1) * -incx : 0; row < m; row += 1, ix += incx) {
- sum += x[offsetx + ix] * a[offseta + row + col * lda];
- }
- if (beta != 0.0f) {
- y[offsety + iy] = alpha * sum + beta * y[offsety + iy];
- } else {
- y[offsety + iy] = alpha * sum;
- }
- }
- }
-
- protected void dgerK(int m, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
- int col = 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 4); col += 4, iy += incy * 4) {
- double alphayiy0 = alpha * y[offsety + iy + incy * 0];
- double alphayiy1 = alpha * y[offsety + iy + incy * 1];
- double alphayiy2 = alpha * y[offsety + iy + incy * 2];
- double alphayiy3 = alpha * y[offsety + iy + incy * 3];
- int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0;
- for (; row < m; row += 1, jx += incx) {
- double xjx = x[offsetx + jx];
- a[offseta + row + (col + 0) * lda] += alphayiy0 * xjx;
- a[offseta + row + (col + 1) * lda] += alphayiy1 * xjx;
- a[offseta + row + (col + 2) * lda] += alphayiy2 * xjx;
- a[offseta + row + (col + 3) * lda] += alphayiy3 * xjx;
- }
- }
- for (; col < n; col += 1, iy += incy) {
- double alphayiy = alpha * y[offsety + iy];
- int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0;
- for (; row < m; row += 1, jx += incx) {
- a[offseta + row + col * lda] += alphayiy * x[offsetx + jx];
- }
- }
- }
-
- protected void sgerK(int m, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
- int col = 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 4); col += 4, iy += incy * 4) {
- float alphayiy0 = alpha * y[offsety + iy + incy * 0];
- float alphayiy1 = alpha * y[offsety + iy + incy * 1];
- float alphayiy2 = alpha * y[offsety + iy + incy * 2];
- float alphayiy3 = alpha * y[offsety + iy + incy * 3];
- int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0;
- for (; row < m; row += 1, jx += incx) {
- float xjx = x[offsetx + jx];
- a[offseta + row + (col + 0) * lda] += alphayiy0 * xjx;
- a[offseta + row + (col + 1) * lda] += alphayiy1 * xjx;
- a[offseta + row + (col + 2) * lda] += alphayiy2 * xjx;
- a[offseta + row + (col + 3) * lda] += alphayiy3 * xjx;
- }
- }
- for (; col < n; col += 1, iy += incy) {
- float alphayiy = alpha * y[offsety + iy];
- int row = 0, jx = incx < 0 ? (n - 1) * -incx : 0;
- for (; row < m; row += 1, jx += incx) {
- a[offseta + row + col * lda] += alphayiy * x[offsetx + jx];
- }
- }
- }
-
- protected double dnrm2K(int n, double[] x, int offsetx, int incx) {
- int ix = 0;
- double sum0 = 0.0;
- double sum1 = 0.0;
- double sum2 = 0.0;
- double sum3 = 0.0;
- if (incx == 1) {
- for (; ix < loopBound(n, 4); ix += 4) {
- double x0 = x[offsetx + ix + 0];
- double x1 = x[offsetx + ix + 1];
- double x2 = x[offsetx + ix + 2];
- double x3 = x[offsetx + ix + 3];
- sum0 += x0 * x0;
- sum1 += x1 * x1;
- sum2 += x2 * x2;
- sum3 += x3 * x3;
- }
- } else {
- for (; ix < loopBound(n, 4) * incx; ix += 4 * incx) {
- double x0 = x[offsetx + ix + (0 * incx)];
- double x1 = x[offsetx + ix + (1 * incx)];
- double x2 = x[offsetx + ix + (2 * incx)];
- double x3 = x[offsetx + ix + (3 * incx)];
- sum0 += x0 * x0;
- sum1 += x1 * x1;
- sum2 += x2 * x2;
- sum3 += x3 * x3;
- }
- }
- double sum = sum0 + sum1 + sum2 + sum3;
- for (; ix < n * incx; ix += incx) {
- double x0 = x[offsetx + ix + 0];
- sum += x0 * x0;
- }
- return Math.sqrt(sum);
- }
-
- protected float snrm2K(int n, float[] x, int offsetx, int incx) {
- int ix = 0;
- float sum0 = 0.0f;
- float sum1 = 0.0f;
- float sum2 = 0.0f;
- float sum3 = 0.0f;
- if (incx == 1) {
- for (; ix < loopBound(n, 4); ix += 4) {
- float x0 = x[offsetx + ix + 0];
- float x1 = x[offsetx + ix + 1];
- float x2 = x[offsetx + ix + 2];
- float x3 = x[offsetx + ix + 3];
- sum0 += x0 * x0;
- sum1 += x1 * x1;
- sum2 += x2 * x2;
- sum3 += x3 * x3;
- }
- } else {
- for (; ix < loopBound(n, 4) * incx; ix += 4 * incx) {
- float x0 = x[offsetx + ix + (0 * incx)];
- float x1 = x[offsetx + ix + (1 * incx)];
- float x2 = x[offsetx + ix + (2 * incx)];
- float x3 = x[offsetx + ix + (3 * incx)];
- sum0 += x0 * x0;
- sum1 += x1 * x1;
- sum2 += x2 * x2;
- sum3 += x3 * x3;
- }
- }
- float sum = sum0 + sum1 + sum2 + sum3;
- for (; ix < n * incx; ix += incx) {
- float x0 = x[offsetx + ix + 0];
- sum += x0 * x0;
- }
- return (float)Math.sqrt(sum);
- }
-
- protected void drotK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double c, double s) {
- if (incx == 1 && incy == 1) {
- for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
- double x0 = x[offsetx + ix];
- double y0 = y[offsety + iy];
- x[offsetx + ix] = c * x0 + s * y0;
- y[offsety + iy] = c * y0 - s * x0;
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0,
- iy = incy < 0 ? (n - 1) * -incy : 0;
- (incx < 0 ? ix >= 0 : ix < n * incx)
- && (incy < 0 ? iy >= 0 : iy < n * incy);
- ix += incx, iy += incy) {
- double x0 = x[offsetx + ix];
- double y0 = y[offsety + iy];
- x[offsetx + ix] = c * x0 + s * y0;
- y[offsety + iy] = c * y0 - s * x0;
- }
- }
- }
-
- protected void srotK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float c, float s) {
- if (incx == 1 && incy == 1) {
- for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
- float x0 = x[offsetx + ix];
- float y0 = y[offsety + iy];
- x[offsetx + ix] = c * x0 + s * y0;
- y[offsety + iy] = c * y0 - s * x0;
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0,
- iy = incy < 0 ? (n - 1) * -incy : 0;
- (incx < 0 ? ix >= 0 : ix < n * incx)
- && (incy < 0 ? iy >= 0 : iy < n * incy);
- ix += incx, iy += incy) {
- float x0 = x[offsetx + ix];
- float y0 = y[offsety + iy];
- x[offsetx + ix] = c * x0 + s * y0;
- y[offsety + iy] = c * y0 - s * x0;
- }
- }
- }
-
- protected void drotmK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] param, int offsetparam) {
- org.netlib.blas.Drotm.drotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
- }
-
- protected void srotmK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] param, int offsetparam) {
- org.netlib.blas.Srotm.srotm(n, x, offsetx, incx, y, offsety, incy, param, offsetparam);
- }
-
- protected void drotmgK(org.netlib.util.doubleW dd1, org.netlib.util.doubleW dd2, org.netlib.util.doubleW dx1, double dy1, double[] param, int offsetparam) {
- org.netlib.blas.Drotmg.drotmg(dd1, dd2, dx1, dy1, param, offsetparam);
- }
-
- protected void srotmgK(org.netlib.util.floatW sd1, org.netlib.util.floatW sd2, org.netlib.util.floatW sx1, float sy1, float[] param, int offsetparam) {
- org.netlib.blas.Srotmg.srotmg(sd1, sd2, sx1, sy1, param, offsetparam);
- }
-
- protected void dsbmvK(String uplo, int n, int k, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- org.netlib.blas.Dsbmv.dsbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected void ssbmvK(String uplo, int n, int k, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- org.netlib.blas.Ssbmv.ssbmv(uplo, n, k, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
-
- protected void dscalK(int n, double alpha, double[] x, int offsetx, int incx) {
- if (incx == 1) {
- for (int ix = 0; ix < n; ix += 1) {
- x[offsetx + ix] *= alpha;
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ? ix >= 0 : ix < n * incx; ix += incx) {
- x[offsetx + ix] *= alpha;
- }
- }
- }
-
- protected void sscalK(int n, float alpha, float[] x, int offsetx, int incx) {
- if (incx == 1) {
- for (int ix = 0; ix < n; ix += 1) {
- x[offsetx + ix] *= alpha;
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0; incx < 0 ? ix >= 0 : ix < n * incx; ix += incx) {
- x[offsetx + ix] *= alpha;
- }
- }
- }
-
- protected void dspmvK(String uplo, int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- if (alpha == 0.0) {
- for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
- if (beta != 0.0) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0;
- }
- }
- } else if (lsame("U", uplo)) {
- dspmvU(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
- } else if (lsame("L", uplo)) {
- dspmvL(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
- }
- }
-
- protected void dspmvU(int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
- double alphaxix0 = alpha * x[offsetx + ix + incx * 0];
- double alphaxix1 = alpha * x[offsetx + ix + incx * 1];
- double alphaxix2 = alpha * x[offsetx + ix + incx * 2];
- double alphaxix3 = alpha * x[offsetx + ix + incx * 3];
- double sumiy0 = 0.0;
- double sumiy1 = 0.0;
- double sumiy2 = 0.0;
- double sumiy3 = 0.0;
- int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
- for (; row < col; row += 1, jx += incx, jy += incy) {
- double a0 = a[offseta + row + (col + 0) * ((col + 0) + 1) / 2];
- double a1 = a[offseta + row + (col + 1) * ((col + 1) + 1) / 2];
- double a2 = a[offseta + row + (col + 2) * ((col + 2) + 1) / 2];
- double a3 = a[offseta + row + (col + 3) * ((col + 3) + 1) / 2];
- y[offsety + jy] += alphaxix0 * a0
- + alphaxix1 * a1
- + alphaxix2 * a2
- + alphaxix3 * a3;
- double xjx = x[offsetx + jx];
- sumiy0 += xjx * a0;
- sumiy1 += xjx * a1;
- sumiy2 += xjx * a2;
- sumiy3 += xjx * a3;
- }
- double a00 = a[offseta + (row + 0) + (col + 0) * ((col + 0) + 1) / 2];
- double a01 = a[offseta + (row + 0) + (col + 1) * ((col + 1) + 1) / 2];
- double a02 = a[offseta + (row + 0) + (col + 2) * ((col + 2) + 1) / 2];
- double a03 = a[offseta + (row + 0) + (col + 3) * ((col + 3) + 1) / 2];
- double a11 = a[offseta + (row + 1) + (col + 1) * ((col + 1) + 1) / 2];
- double a12 = a[offseta + (row + 1) + (col + 2) * ((col + 2) + 1) / 2];
- double a13 = a[offseta + (row + 1) + (col + 3) * ((col + 3) + 1) / 2];
- double a22 = a[offseta + (row + 2) + (col + 2) * ((col + 2) + 1) / 2];
- double a23 = a[offseta + (row + 2) + (col + 3) * ((col + 3) + 1) / 2];
- double a33 = a[offseta + (row + 3) + (col + 3) * ((col + 3) + 1) / 2];
- double xjx0 = x[offsetx + jx + incx * 0];
- double xjx1 = x[offsetx + jx + incx * 1];
- double xjx2 = x[offsetx + jx + incx * 2];
- double xjx3 = x[offsetx + jx + incx * 3];
- sumiy0 += xjx0 * a00
- + xjx1 * a01
- + xjx2 * a02
- + xjx3 * a03;
- sumiy1 += xjx0 * a01
- + xjx1 * a11
- + xjx2 * a12
- + xjx3 * a13;
- sumiy2 += xjx0 * a02
- + xjx1 * a12
- + xjx2 * a22
- + xjx3 * a23;
- sumiy3 += xjx0 * a03
- + xjx1 * a13
- + xjx2 * a23
- + xjx3 * a33;
- if (beta != 0.0) {
- y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0];
- y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1];
- y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy + incy * 2];
- y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3];
- } else {
- y[offsety + iy + incy * 0] = alpha * sumiy0;
- y[offsety + iy + incy * 1] = alpha * sumiy1;
- y[offsety + iy + incy * 2] = alpha * sumiy2;
- y[offsety + iy + incy * 3] = alpha * sumiy3;
- }
- }
- for (; col < n; col += 1, ix += incx, iy += incy) {
- double alphaxix = alpha * x[offsetx + ix];
- double sumiy = 0.0;
- int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
- for (; row < col; row += 1, jx += incx, jy += incy) {
- y[offsety + jy] += alphaxix * a[offseta + row + col * (col + 1) / 2];
- sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2];
- }
- sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2];
- if (beta != 0.0) {
- y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy];
- } else {
- y[offsety + iy] = alpha * sumiy;
- }
- }
- }
-
- protected void dspmvL(int n, double alpha, double[] a, int offseta, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- // y = beta * y
- if (beta != 1.0) {
- for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
- if (beta != 0.0) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0;
- }
- }
- }
- // y += alpha * A * x
- int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
- double alphaxix0 = alpha * x[offsetx + ix + incx * 0];
- double alphaxix1 = alpha * x[offsetx + ix + incx * 1];
- double alphaxix2 = alpha * x[offsetx + ix + incx * 2];
- double alphaxix3 = alpha * x[offsetx + ix + incx * 3];
- double sumiy0 = 0.0;
- double sumiy1 = 0.0;
- double sumiy2 = 0.0;
- double sumiy3 = 0.0;
- double a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
- double a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
- double a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
- double a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
- double a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
- double a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * (2 * n - (col + 2) - 1) / 2];
- double a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
- double a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
- double a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * (2 * n - (col + 2) - 1) / 2];
- double a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * (2 * n - (col + 3) - 1) / 2];
- double x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)];
- double x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)];
- double x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)];
- double x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)];
- sumiy0 += x0 * a00
- + x1 * a10
- + x2 * a20
- + x3 * a30;
- sumiy1 += x0 * a10
- + x1 * a11
- + x2 * a21
- + x3 * a31;
- sumiy2 += x0 * a20
- + x1 * a21
- + x2 * a22
- + x3 * a32;
- sumiy3 += x0 * a30
- + x1 * a31
- + x2 * a32
- + x3 * a33;
- int row = col + 4, jx = incx < 0 ? (n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? (n - (col + 4) - 1) * -incy : (col + 4) * incy;
- for (; row < n; row += 1, jx += incx, jy += incy) {
- double a0 = a[offseta + row + (col + 0) * (2 * n - (col + 0) - 1) / 2];
- double a1 = a[offseta + row + (col + 1) * (2 * n - (col + 1) - 1) / 2];
- double a2 = a[offseta + row + (col + 2) * (2 * n - (col + 2) - 1) / 2];
- double a3 = a[offseta + row + (col + 3) * (2 * n - (col + 3) - 1) / 2];
- y[offsety + jy] += alphaxix0 * a0
- + alphaxix1 * a1
- + alphaxix2 * a2
- + alphaxix3 * a3;
- double xjx = x[offsetx + jx];
- sumiy0 += xjx * a0;
- sumiy1 += xjx * a1;
- sumiy2 += xjx * a2;
- sumiy3 += xjx * a3;
- }
- y[offsety + iy + incy * 0] += alpha * sumiy0;
- y[offsety + iy + incy * 1] += alpha * sumiy1;
- y[offsety + iy + incy * 2] += alpha * sumiy2;
- y[offsety + iy + incy * 3] += alpha * sumiy3;
- }
- for (; col < n; col += 1, ix += incx, iy += incy) {
- double alphaxix = alpha * x[offsetx + ix];
- double sumiy = 0.0;
- sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * (2 * n - col - 1) / 2];
- int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? (n - (col + 1) - 1) * -incy : (col + 1) * incy;
- for (; row < n; row += 1, jx += incx, jy += incy) {
- y[offsety + jy] += alphaxix * a[offseta + row + col * (2 * n - col - 1) / 2];
- sumiy += x[offsetx + jx] * a[offseta + row + col * (2 * n - col - 1) / 2];
- }
- y[offsety + iy] += alpha * sumiy;
- }
- }
-
- protected void sspmvK(String uplo, int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- if (alpha == 0.0f) {
- for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
- if (beta != 0.0f) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0f;
- }
- }
- } else if (lsame("U", uplo)) {
- sspmvU(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
- } else if (lsame("L", uplo)) {
- sspmvL(n, alpha, a, offseta, x, offsetx, incx, beta, y, offsety, incy);
- }
- }
-
- protected void sspmvU(int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
- float alphaxix0 = alpha * x[offsetx + ix + incx * 0];
- float alphaxix1 = alpha * x[offsetx + ix + incx * 1];
- float alphaxix2 = alpha * x[offsetx + ix + incx * 2];
- float alphaxix3 = alpha * x[offsetx + ix + incx * 3];
- float sumiy0 = 0.0f;
- float sumiy1 = 0.0f;
- float sumiy2 = 0.0f;
- float sumiy3 = 0.0f;
- int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
- for (; row < col; row += 1, jx += incx, jy += incy) {
- float a0 = a[offseta + row + (col + 0) * ((col + 0) + 1) / 2];
- float a1 = a[offseta + row + (col + 1) * ((col + 1) + 1) / 2];
- float a2 = a[offseta + row + (col + 2) * ((col + 2) + 1) / 2];
- float a3 = a[offseta + row + (col + 3) * ((col + 3) + 1) / 2];
- y[offsety + jy] += alphaxix0 * a0
- + alphaxix1 * a1
- + alphaxix2 * a2
- + alphaxix3 * a3;
- float xjx = x[offsetx + jx];
- sumiy0 += xjx * a0;
- sumiy1 += xjx * a1;
- sumiy2 += xjx * a2;
- sumiy3 += xjx * a3;
- }
- float a00 = a[offseta + (row + 0) + (col + 0) * ((col + 0) + 1) / 2];
- float a01 = a[offseta + (row + 0) + (col + 1) * ((col + 1) + 1) / 2];
- float a02 = a[offseta + (row + 0) + (col + 2) * ((col + 2) + 1) / 2];
- float a03 = a[offseta + (row + 0) + (col + 3) * ((col + 3) + 1) / 2];
- float a11 = a[offseta + (row + 1) + (col + 1) * ((col + 1) + 1) / 2];
- float a12 = a[offseta + (row + 1) + (col + 2) * ((col + 2) + 1) / 2];
- float a13 = a[offseta + (row + 1) + (col + 3) * ((col + 3) + 1) / 2];
- float a22 = a[offseta + (row + 2) + (col + 2) * ((col + 2) + 1) / 2];
- float a23 = a[offseta + (row + 2) + (col + 3) * ((col + 3) + 1) / 2];
- float a33 = a[offseta + (row + 3) + (col + 3) * ((col + 3) + 1) / 2];
- float xjx0 = x[offsetx + jx + incx * 0];
- float xjx1 = x[offsetx + jx + incx * 1];
- float xjx2 = x[offsetx + jx + incx * 2];
- float xjx3 = x[offsetx + jx + incx * 3];
- sumiy0 += xjx0 * a00
- + xjx1 * a01
- + xjx2 * a02
- + xjx3 * a03;
- sumiy1 += xjx0 * a01
- + xjx1 * a11
- + xjx2 * a12
- + xjx3 * a13;
- sumiy2 += xjx0 * a02
- + xjx1 * a12
- + xjx2 * a22
- + xjx3 * a23;
- sumiy3 += xjx0 * a03
- + xjx1 * a13
- + xjx2 * a23
- + xjx3 * a33;
- if (beta != 0.0f) {
- y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0];
- y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1];
- y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy + incy * 2];
- y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3];
- } else {
- y[offsety + iy + incy * 0] = alpha * sumiy0;
- y[offsety + iy + incy * 1] = alpha * sumiy1;
- y[offsety + iy + incy * 2] = alpha * sumiy2;
- y[offsety + iy + incy * 3] = alpha * sumiy3;
- }
- }
- for (; col < n; col += 1, ix += incx, iy += incy) {
- float alphaxix = alpha * x[offsetx + ix];
- float sumiy = 0.0f;
- int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
- for (; row < col; row += 1, jx += incx, jy += incy) {
- y[offsety + jy] += alphaxix * a[offseta + row + col * (col + 1) / 2];
- sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2];
- }
- sumiy += x[offsetx + jx] * a[offseta + row + col * (col + 1) / 2];
- if (beta != 0.0f) {
- y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy];
- } else {
- y[offsety + iy] = alpha * sumiy;
- }
- }
- }
-
- protected void sspmvL(int n, float alpha, float[] a, int offseta, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- // y = beta * y
- if (beta != 1.0f) {
- for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
- if (beta != 0.0f) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0f;
- }
- }
- }
- // y += alpha * A * x
- int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
- float alphaxix0 = alpha * x[offsetx + ix + incx * 0];
- float alphaxix1 = alpha * x[offsetx + ix + incx * 1];
- float alphaxix2 = alpha * x[offsetx + ix + incx * 2];
- float alphaxix3 = alpha * x[offsetx + ix + incx * 3];
- float sumiy0 = 0.0f;
- float sumiy1 = 0.0f;
- float sumiy2 = 0.0f;
- float sumiy3 = 0.0f;
- float a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
- float a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
- float a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
- float a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
- float a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
- float a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * (2 * n - (col + 2) - 1) / 2];
- float a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * (2 * n - (col + 0) - 1) / 2];
- float a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * (2 * n - (col + 1) - 1) / 2];
- float a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * (2 * n - (col + 2) - 1) / 2];
- float a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * (2 * n - (col + 3) - 1) / 2];
- float x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)];
- float x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)];
- float x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)];
- float x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)];
- sumiy0 += x0 * a00
- + x1 * a10
- + x2 * a20
- + x3 * a30;
- sumiy1 += x0 * a10
- + x1 * a11
- + x2 * a21
- + x3 * a31;
- sumiy2 += x0 * a20
- + x1 * a21
- + x2 * a22
- + x3 * a32;
- sumiy3 += x0 * a30
- + x1 * a31
- + x2 * a32
- + x3 * a33;
- int row = col + 4, jx = incx < 0 ? (n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? (n - (col + 4) - 1) * -incy : (col + 4) * incy;
- for (; row < n; row += 1, jx += incx, jy += incy) {
- float a0 = a[offseta + row + (col + 0) * (2 * n - (col + 0) - 1) / 2];
- float a1 = a[offseta + row + (col + 1) * (2 * n - (col + 1) - 1) / 2];
- float a2 = a[offseta + row + (col + 2) * (2 * n - (col + 2) - 1) / 2];
- float a3 = a[offseta + row + (col + 3) * (2 * n - (col + 3) - 1) / 2];
- y[offsety + jy] += alphaxix0 * a0
- + alphaxix1 * a1
- + alphaxix2 * a2
- + alphaxix3 * a3;
- float xjx = x[offsetx + jx];
- sumiy0 += xjx * a0;
- sumiy1 += xjx * a1;
- sumiy2 += xjx * a2;
- sumiy3 += xjx * a3;
- }
- y[offsety + iy + incy * 0] += alpha * sumiy0;
- y[offsety + iy + incy * 1] += alpha * sumiy1;
- y[offsety + iy + incy * 2] += alpha * sumiy2;
- y[offsety + iy + incy * 3] += alpha * sumiy3;
- }
- for (; col < n; col += 1, ix += incx, iy += incy) {
- float alphaxix = alpha * x[offsetx + ix];
- float sumiy = 0.0f;
- sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * (2 * n - col - 1) / 2];
- int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? (n - (col + 1) - 1) * -incy : (col + 1) * incy;
- for (; row < n; row += 1, jx += incx, jy += incy) {
- y[offsety + jy] += alphaxix * a[offseta + row + col * (2 * n - col - 1) / 2];
- sumiy += x[offsetx + jx] * a[offseta + row + col * (2 * n - col - 1) / 2];
- }
- y[offsety + iy] += alpha * sumiy;
- }
- }
-
- protected void dsprK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta) {
- org.netlib.blas.Dspr.dspr(uplo, n, alpha, x, offsetx, incx, a, offseta);
- }
-
- protected void ssprK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta) {
- org.netlib.blas.Sspr.sspr(uplo, n, alpha, x, offsetx, incx, a, offseta);
- }
-
- protected void dspr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta) {
- org.netlib.blas.Dspr2.dspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
- }
-
- protected void sspr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta) {
- org.netlib.blas.Sspr2.sspr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta);
- }
-
- protected void dswapK(int n, double[] x, int offsetx, int incx, double[] y, int offsety, int incy) {
- if (incx == 1 && incy == 1) {
- for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
- double tmp = y[offsety + iy];
- y[offsety + iy] = x[offsetx + ix];
- x[offsetx + ix] = tmp;
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0,
- iy = incy < 0 ? (n - 1) * -incy : 0;
- (incx < 0 ? ix >= 0 : ix < n * incx)
- && (incy < 0 ? iy >= 0 : iy < n * incy);
- ix += incx, iy += incy) {
- double tmp = y[offsety + iy];
- y[offsety + iy] = x[offsetx + ix];
- x[offsetx + ix] = tmp;
- }
- }
- }
-
- protected void sswapK(int n, float[] x, int offsetx, int incx, float[] y, int offsety, int incy) {
- if (incx == 1 && incy == 1) {
- for (int ix = 0, iy = 0; ix < n && iy < n; ix += 1, iy += 1) {
- float tmp = y[offsety + iy];
- y[offsety + iy] = x[offsetx + ix];
- x[offsetx + ix] = tmp;
- }
- } else {
- for (int ix = incx < 0 ? (n - 1) * -incx : 0,
- iy = incy < 0 ? (n - 1) * -incy : 0;
- (incx < 0 ? ix >= 0 : ix < n * incx)
- && (incy < 0 ? iy >= 0 : iy < n * incy);
- ix += incx, iy += incy) {
- float tmp = y[offsety + iy];
- y[offsety + iy] = x[offsetx + ix];
- x[offsetx + ix] = tmp;
- }
- }
- }
-
- protected void dsymmK(String side, String uplo, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- if (alpha == 0.0) {
- // C := beta*C
- int col = 0;
- for (; col < loopBound(n, 4); col += 4) {
- int row = 0;
- for (; row < m; row += 1) {
- if (beta != 0.0) {
- c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
- c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc];
- c[offsetc + row + (col + 2) * ldc] = beta * c[offsetc + row + (col + 2) * ldc];
- c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc];
- } else {
- c[offsetc + row + (col + 0) * ldc] = 0.0;
- c[offsetc + row + (col + 1) * ldc] = 0.0;
- c[offsetc + row + (col + 2) * ldc] = 0.0;
- c[offsetc + row + (col + 3) * ldc] = 0.0;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < m; row += 1) {
- if (beta != 0.0) {
- c[offsetc + row + col * ldc] = beta * c[offsetc + row + col * ldc];
- } else {
- c[offsetc + row + col * ldc] = 0.0;
- }
- }
- }
- } else if (lsame("L", side) && lsame("U", uplo)) {
- dsymmLU(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else if (lsame("L", side) && lsame("L", uplo)) {
- dsymmLL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else if (lsame("R", side) && lsame("U", uplo)) {
- dsymmRU(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else if (lsame("R", side) && lsame("L", uplo)) {
- dsymmRL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- }
-
- protected void dsymmLU(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- // C := alpha*A*B + beta*C
- int col = 0;
- for (; col < loopBound(n, 4); col += 4) {
- int row = 0;
- for (; row < loopBound(m, 4); row += 4) {
- double sum00 = 0.0;
- double sum10 = 0.0;
- double sum20 = 0.0;
- double sum30 = 0.0;
- double sum01 = 0.0;
- double sum11 = 0.0;
- double sum21 = 0.0;
- double sum31 = 0.0;
- double sum02 = 0.0;
- double sum12 = 0.0;
- double sum22 = 0.0;
- double sum32 = 0.0;
- double sum03 = 0.0;
- double sum13 = 0.0;
- double sum23 = 0.0;
- double sum33 = 0.0;
- double alphab00 = alpha * b[offsetb + (row + 0) + (col + 0) * ldb];
- double alphab10 = alpha * b[offsetb + (row + 1) + (col + 0) * ldb];
- double alphab20 = alpha * b[offsetb + (row + 2) + (col + 0) * ldb];
- double alphab30 = alpha * b[offsetb + (row + 3) + (col + 0) * ldb];
- double alphab01 = alpha * b[offsetb + (row + 0) + (col + 1) * ldb];
- double alphab11 = alpha * b[offsetb + (row + 1) + (col + 1) * ldb];
- double alphab21 = alpha * b[offsetb + (row + 2) + (col + 1) * ldb];
- double alphab31 = alpha * b[offsetb + (row + 3) + (col + 1) * ldb];
- double alphab02 = alpha * b[offsetb + (row + 0) + (col + 2) * ldb];
- double alphab12 = alpha * b[offsetb + (row + 1) + (col + 2) * ldb];
- double alphab22 = alpha * b[offsetb + (row + 2) + (col + 2) * ldb];
- double alphab32 = alpha * b[offsetb + (row + 3) + (col + 2) * ldb];
- double alphab03 = alpha * b[offsetb + (row + 0) + (col + 3) * ldb];
- double alphab13 = alpha * b[offsetb + (row + 1) + (col + 3) * ldb];
- double alphab23 = alpha * b[offsetb + (row + 2) + (col + 3) * ldb];
- double alphab33 = alpha * b[offsetb + (row + 3) + (col + 3) * ldb];
- int i = 0;
- for (; i < row; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double a1 = a[offseta + i + (row + 1) * lda];
- double a2 = a[offseta + i + (row + 2) * lda];
- double a3 = a[offseta + i + (row + 3) * lda];
- c[offsetc + i + (col + 0) * ldc] += alphab00 * a0
- + alphab10 * a1
- + alphab20 * a2
- + alphab30 * a3;
- c[offsetc + i + (col + 1) * ldc] += alphab01 * a0
- + alphab11 * a1
- + alphab21 * a2
- + alphab31 * a3;
- c[offsetc + i + (col + 2) * ldc] += alphab02 * a0
- + alphab12 * a1
- + alphab22 * a2
- + alphab32 * a3;
- c[offsetc + i + (col + 3) * ldc] += alphab03 * a0
- + alphab13 * a1
- + alphab23 * a2
- + alphab33 * a3;
- double b0 = b[offsetb + i + (col + 0) * ldb];
- double b1 = b[offsetb + i + (col + 1) * ldb];
- double b2 = b[offsetb + i + (col + 2) * ldb];
- double b3 = b[offsetb + i + (col + 3) * ldb];
- sum00 += a0 * b0;
- sum10 += a1 * b0;
- sum20 += a2 * b0;
- sum30 += a3 * b0;
- sum01 += a0 * b1;
- sum11 += a1 * b1;
- sum21 += a2 * b1;
- sum31 += a3 * b1;
- sum02 += a0 * b2;
- sum12 += a1 * b2;
- sum22 += a2 * b2;
- sum32 += a3 * b2;
- sum03 += a0 * b3;
- sum13 += a1 * b3;
- sum23 += a2 * b3;
- sum33 += a3 * b3;
- }
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double a01 = a[offseta + (i + 0) + (row + 1) * lda];
- double a02 = a[offseta + (i + 0) + (row + 2) * lda];
- double a03 = a[offseta + (i + 0) + (row + 3) * lda];
- double a11 = a[offseta + (i + 1) + (row + 1) * lda];
- double a12 = a[offseta + (i + 1) + (row + 2) * lda];
- double a13 = a[offseta + (i + 1) + (row + 3) * lda];
- double a22 = a[offseta + (i + 2) + (row + 2) * lda];
- double a23 = a[offseta + (i + 2) + (row + 3) * lda];
- double a33 = a[offseta + (i + 3) + (row + 3) * lda];
- double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- double b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- double b20 = b[offsetb + (i + 2) + (col + 0) * ldb];
- double b30 = b[offsetb + (i + 3) + (col + 0) * ldb];
- double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- double b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- double b21 = b[offsetb + (i + 2) + (col + 1) * ldb];
- double b31 = b[offsetb + (i + 3) + (col + 1) * ldb];
- double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- double b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- double b22 = b[offsetb + (i + 2) + (col + 2) * ldb];
- double b32 = b[offsetb + (i + 3) + (col + 2) * ldb];
- double b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
- double b13 = b[offsetb + (i + 1) + (col + 3) * ldb];
- double b23 = b[offsetb + (i + 2) + (col + 3) * ldb];
- double b33 = b[offsetb + (i + 3) + (col + 3) * ldb];
- sum00 += a00 * b00 + a01 * b10 + a02 * b20 + a03 * b30;
- sum10 += a01 * b00 + a11 * b10 + a12 * b20 + a13 * b30;
- sum20 += a02 * b00 + a12 * b10 + a22 * b20 + a23 * b30;
- sum30 += a03 * b00 + a13 * b10 + a23 * b20 + a33 * b30;
- sum01 += a00 * b01 + a01 * b11 + a02 * b21 + a03 * b31;
- sum11 += a01 * b01 + a11 * b11 + a12 * b21 + a13 * b31;
- sum21 += a02 * b01 + a12 * b11 + a22 * b21 + a23 * b31;
- sum31 += a03 * b01 + a13 * b11 + a23 * b21 + a33 * b31;
- sum02 += a00 * b02 + a01 * b12 + a02 * b22 + a03 * b32;
- sum12 += a01 * b02 + a11 * b12 + a12 * b22 + a13 * b32;
- sum22 += a02 * b02 + a12 * b12 + a22 * b22 + a23 * b32;
- sum32 += a03 * b02 + a13 * b12 + a23 * b22 + a33 * b32;
- sum03 += a00 * b03 + a01 * b13 + a02 * b23 + a03 * b33;
- sum13 += a01 * b03 + a11 * b13 + a12 * b23 + a13 * b33;
- sum23 += a02 * b03 + a12 * b13 + a22 * b23 + a23 * b33;
- sum33 += a03 * b03 + a13 * b13 + a23 * b23 + a33 * b33;
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc];
- c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc];
- c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc];
- c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc];
- c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32;
- c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03;
- c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13;
- c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23;
- c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33;
- }
- }
- for (; row < m; row += 1) {
- double sum0 = 0.0;
- double sum1 = 0.0;
- double sum2 = 0.0;
- double sum3 = 0.0;
- double alphab0 = alpha * b[offsetb + row + (col + 0) * ldb];
- double alphab1 = alpha * b[offsetb + row + (col + 1) * ldb];
- double alphab2 = alpha * b[offsetb + row + (col + 2) * ldb];
- double alphab3 = alpha * b[offsetb + row + (col + 3) * ldb];
- int i = 0;
- for (; i < row; i += 1) {
- double a0 = a[offseta + i + row * lda];
- c[offsetc + i + (col + 0) * ldc] += alphab0 * a0;
- c[offsetc + i + (col + 1) * ldc] += alphab1 * a0;
- c[offsetc + i + (col + 2) * ldc] += alphab2 * a0;
- c[offsetc + i + (col + 3) * ldc] += alphab3 * a0;
- sum0 += b[offsetb + i + (col + 0) * ldb] * a0;
- sum1 += b[offsetb + i + (col + 1) * ldb] * a0;
- sum2 += b[offsetb + i + (col + 2) * ldb] * a0;
- sum3 += b[offsetb + i + (col + 3) * ldb] * a0;
- }
- double a0 = a[offseta + i + row * lda];
- sum0 += b[offsetb + i + (col + 0) * ldb] * a0;
- sum1 += b[offsetb + i + (col + 1) * ldb] * a0;
- sum2 += b[offsetb + i + (col + 2) * ldb] * a0;
- sum3 += b[offsetb + i + (col + 3) * ldb] * a0;
- if (beta != 0.0) {
- c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc];
- c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc];
- c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc];
- c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * c[offsetc + row + (col + 3) * ldc];
- } else {
- c[offsetc + row + (col + 0) * ldc] = alpha * sum0;
- c[offsetc + row + (col + 1) * ldc] = alpha * sum1;
- c[offsetc + row + (col + 2) * ldc] = alpha * sum2;
- c[offsetc + row + (col + 3) * ldc] = alpha * sum3;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, 4); row += 4) {
- double sum0 = 0.0;
- double sum1 = 0.0;
- double sum2 = 0.0;
- double sum3 = 0.0;
- double alphab0 = alpha * b[offsetb + (row + 0) + col * ldb];
- double alphab1 = alpha * b[offsetb + (row + 1) + col * ldb];
- double alphab2 = alpha * b[offsetb + (row + 2) + col * ldb];
- double alphab3 = alpha * b[offsetb + (row + 3) + col * ldb];
- int i = 0;
- for (; i < row; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double a1 = a[offseta + i + (row + 1) * lda];
- double a2 = a[offseta + i + (row + 2) * lda];
- double a3 = a[offseta + i + (row + 3) * lda];
- c[offsetc + i + col * ldc] += alphab0 * a0
- + alphab1 * a1
- + alphab2 * a2
- + alphab3 * a3;
- double b0 = b[offsetb + i + col * ldb];
- sum0 += b0 * a0;
- sum1 += b0 * a1;
- sum2 += b0 * a2;
- sum3 += b0 * a3;
- }
- double a00 = a[offseta + (i + 0) + (row + 0) * lda];
- double a01 = a[offseta + (i + 0) + (row + 1) * lda];
- double a02 = a[offseta + (i + 0) + (row + 2) * lda];
- double a03 = a[offseta + (i + 0) + (row + 3) * lda];
- double a11 = a[offseta + (i + 1) + (row + 1) * lda];
- double a12 = a[offseta + (i + 1) + (row + 2) * lda];
- double a13 = a[offseta + (i + 1) + (row + 3) * lda];
- double a22 = a[offseta + (i + 2) + (row + 2) * lda];
- double a23 = a[offseta + (i + 2) + (row + 3) * lda];
- double a33 = a[offseta + (i + 3) + (row + 3) * lda];
- double b0 = b[offsetb + (i + 0) + col * ldb];
- double b1 = b[offsetb + (i + 1) + col * ldb];
- double b2 = b[offsetb + (i + 2) + col * ldb];
- double b3 = b[offsetb + (i + 3) + col * ldb];
- sum0 += b0 * a00 + b1 * a01 + b2 * a02 + b3 * a03;
- sum1 += b0 * a01 + b1 * a11 + b2 * a12 + b3 * a13;
- sum2 += b0 * a02 + b1 * a12 + b2 * a22 + b3 * a23;
- sum3 += b0 * a03 + b1 * a13 + b2 * a23 + b3 * a33;
- if (beta != 0.0) {
- c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc];
- c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc];
- c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc];
- c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc];
- } else {
- c[offsetc + (row + 0) + col * ldc] = alpha * sum0;
- c[offsetc + (row + 1) + col * ldc] = alpha * sum1;
- c[offsetc + (row + 2) + col * ldc] = alpha * sum2;
- c[offsetc + (row + 3) + col * ldc] = alpha * sum3;
- }
- }
- for (; row < m; row += 1) {
- double alphab = alpha * b[offsetb + row + col * ldb];
- double sum = 0.0;
- int i = 0;
- for (; i < row; i += 1) {
- double aval = a[offseta + i + row * lda];
- c[offsetc + i + col * ldc] += alphab * aval;
- sum += b[offsetb + i + col * ldb] * aval;
- }
- sum += b[offsetb + i + col * ldb] * a[offseta + i + row * lda];
- if (beta != 0.0) {
- c[offsetc + row + col * ldc] = alpha * sum + beta * c[offsetc + row + col * ldc];
- } else {
- c[offsetc + row + col * ldc] = alpha * sum;
- }
- }
- }
- }
-
- protected void dsymmLL(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- final int Srow = 4;
- // C := alpha*A*B + beta*C
- int col = 0;
- for (; col < loopBound(n, 4); col += 4) {
- int row = m - 1;
- for (; row >= loopBound(m - 1, Srow); row -= 1) {
- double alphab0 = alpha * b[offsetb + row + (col + 0) * ldb];
- double alphab1 = alpha * b[offsetb + row + (col + 1) * ldb];
- double alphab2 = alpha * b[offsetb + row + (col + 2) * ldb];
- double alphab3 = alpha * b[offsetb + row + (col + 3) * ldb];
- double sum0 = 0.0;
- double sum1 = 0.0;
- double sum2 = 0.0;
- double sum3 = 0.0;
- sum0 += b[offsetb + row + (col + 0) * ldb] * a[offseta + row + row * lda];
- sum1 += b[offsetb + row + (col + 1) * ldb] * a[offseta + row + row * lda];
- sum2 += b[offsetb + row + (col + 2) * ldb] * a[offseta + row + row * lda];
- sum3 += b[offsetb + row + (col + 3) * ldb] * a[offseta + row + row * lda];
- int i = row + 1;
- for (; i < m; i += 1) {
- double airow = a[offseta + i + row * lda];
- c[offsetc + i + (col + 0) * ldc] += alphab0 * airow;
- c[offsetc + i + (col + 1) * ldc] += alphab1 * airow;
- c[offsetc + i + (col + 2) * ldc] += alphab2 * airow;
- c[offsetc + i + (col + 3) * ldc] += alphab3 * airow;
- sum0 += b[offsetb + i + (col + 0) * ldb] * airow;
- sum1 += b[offsetb + i + (col + 1) * ldb] * airow;
- sum2 += b[offsetb + i + (col + 2) * ldb] * airow;
- sum3 += b[offsetb + i + (col + 3) * ldb] * airow;
- }
- if (beta != 0.0) {
- c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc];
- c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc];
- c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc];
- c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * c[offsetc + row + (col + 3) * ldc];
- } else {
- c[offsetc + row + (col + 0) * ldc] = alpha * sum0;
- c[offsetc + row + (col + 1) * ldc] = alpha * sum1;
- c[offsetc + row + (col + 2) * ldc] = alpha * sum2;
- c[offsetc + row + (col + 3) * ldc] = alpha * sum3;
- }
- }
- for (row -= Srow - 1; row >= 0; row -= Srow) {
- double a00 = a[offseta + (row + 0) + (row + 0) * lda];
- double a10 = a[offseta + (row + 1) + (row + 0) * lda];
- double a11 = a[offseta + (row + 1) + (row + 1) * lda];
- double a20 = a[offseta + (row + 2) + (row + 0) * lda];
- double a21 = a[offseta + (row + 2) + (row + 1) * lda];
- double a22 = a[offseta + (row + 2) + (row + 2) * lda];
- double a30 = a[offseta + (row + 3) + (row + 0) * lda];
- double a31 = a[offseta + (row + 3) + (row + 1) * lda];
- double a32 = a[offseta + (row + 3) + (row + 2) * lda];
- double a33 = a[offseta + (row + 3) + (row + 3) * lda];
- double b00 = b[offsetb + (row + 0) + (col + 0) * ldb];
- double b10 = b[offsetb + (row + 1) + (col + 0) * ldb];
- double b20 = b[offsetb + (row + 2) + (col + 0) * ldb];
- double b30 = b[offsetb + (row + 3) + (col + 0) * ldb];
- double b01 = b[offsetb + (row + 0) + (col + 1) * ldb];
- double b11 = b[offsetb + (row + 1) + (col + 1) * ldb];
- double b21 = b[offsetb + (row + 2) + (col + 1) * ldb];
- double b31 = b[offsetb + (row + 3) + (col + 1) * ldb];
- double b02 = b[offsetb + (row + 0) + (col + 2) * ldb];
- double b12 = b[offsetb + (row + 1) + (col + 2) * ldb];
- double b22 = b[offsetb + (row + 2) + (col + 2) * ldb];
- double b32 = b[offsetb + (row + 3) + (col + 2) * ldb];
- double b03 = b[offsetb + (row + 0) + (col + 3) * ldb];
- double b13 = b[offsetb + (row + 1) + (col + 3) * ldb];
- double b23 = b[offsetb + (row + 2) + (col + 3) * ldb];
- double b33 = b[offsetb + (row + 3) + (col + 3) * ldb];
- double alphab00 = alpha * b00;
- double alphab10 = alpha * b10;
- double alphab20 = alpha * b20;
- double alphab30 = alpha * b30;
- double alphab01 = alpha * b01;
- double alphab11 = alpha * b11;
- double alphab21 = alpha * b21;
- double alphab31 = alpha * b31;
- double alphab02 = alpha * b02;
- double alphab12 = alpha * b12;
- double alphab22 = alpha * b22;
- double alphab32 = alpha * b32;
- double alphab03 = alpha * b03;
- double alphab13 = alpha * b13;
- double alphab23 = alpha * b23;
- double alphab33 = alpha * b33;
- double sum00 = 0.0;
- double sum10 = 0.0;
- double sum20 = 0.0;
- double sum30 = 0.0;
- double sum01 = 0.0;
- double sum11 = 0.0;
- double sum21 = 0.0;
- double sum31 = 0.0;
- double sum02 = 0.0;
- double sum12 = 0.0;
- double sum22 = 0.0;
- double sum32 = 0.0;
- double sum03 = 0.0;
- double sum13 = 0.0;
- double sum23 = 0.0;
- double sum33 = 0.0;
- sum00 += b00 * a00 + b10 * a10 + b20 * a20 + b30 * a30;
- sum10 += b00 * a10 + b10 * a11 + b20 * a21 + b30 * a31;
- sum20 += b00 * a20 + b10 * a21 + b20 * a22 + b30 * a32;
- sum30 += b00 * a30 + b10 * a31 + b20 * a32 + b30 * a33;
- sum01 += b01 * a00 + b11 * a10 + b21 * a20 + b31 * a30;
- sum11 += b01 * a10 + b11 * a11 + b21 * a21 + b31 * a31;
- sum21 += b01 * a20 + b11 * a21 + b21 * a22 + b31 * a32;
- sum31 += b01 * a30 + b11 * a31 + b21 * a32 + b31 * a33;
- sum02 += b02 * a00 + b12 * a10 + b22 * a20 + b32 * a30;
- sum12 += b02 * a10 + b12 * a11 + b22 * a21 + b32 * a31;
- sum22 += b02 * a20 + b12 * a21 + b22 * a22 + b32 * a32;
- sum32 += b02 * a30 + b12 * a31 + b22 * a32 + b32 * a33;
- sum03 += b03 * a00 + b13 * a10 + b23 * a20 + b33 * a30;
- sum13 += b03 * a10 + b13 * a11 + b23 * a21 + b33 * a31;
- sum23 += b03 * a20 + b13 * a21 + b23 * a22 + b33 * a32;
- sum33 += b03 * a30 + b13 * a31 + b23 * a32 + b33 * a33;
- int i = row + 4;
- for (; i < m; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double a1 = a[offseta + i + (row + 1) * lda];
- double a2 = a[offseta + i + (row + 2) * lda];
- double a3 = a[offseta + i + (row + 3) * lda];
- c[offsetc + i + (col + 0) * ldc] += alphab00 * a0
- + alphab10 * a1
- + alphab20 * a2
- + alphab30 * a3;
- c[offsetc + i + (col + 1) * ldc] += alphab01 * a0
- + alphab11 * a1
- + alphab21 * a2
- + alphab31 * a3;
- c[offsetc + i + (col + 2) * ldc] += alphab02 * a0
- + alphab12 * a1
- + alphab22 * a2
- + alphab32 * a3;
- c[offsetc + i + (col + 3) * ldc] += alphab03 * a0
- + alphab13 * a1
- + alphab23 * a2
- + alphab33 * a3;
- double b0 = b[offsetb + i + (col + 0) * ldb];
- double b1 = b[offsetb + i + (col + 1) * ldb];
- double b2 = b[offsetb + i + (col + 2) * ldb];
- double b3 = b[offsetb + i + (col + 3) * ldb];
- sum00 += b0 * a0;
- sum10 += b0 * a1;
- sum20 += b0 * a2;
- sum30 += b0 * a3;
- sum01 += b1 * a0;
- sum11 += b1 * a1;
- sum21 += b1 * a2;
- sum31 += b1 * a3;
- sum02 += b2 * a0;
- sum12 += b2 * a1;
- sum22 += b2 * a2;
- sum32 += b2 * a3;
- sum03 += b3 * a0;
- sum13 += b3 * a1;
- sum23 += b3 * a2;
- sum33 += b3 * a3;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc];
- c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc];
- c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc];
- c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc];
- c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32;
- c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03;
- c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13;
- c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23;
- c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = m - 1;
- for (; row >= loopBound(m - 1, Srow); row -= 1) {
- double alphab0 = alpha * b[offsetb + row + col * ldb];
- double sum0 = 0.0;
- sum0 += b[offsetb + row + col * ldb] * a[offseta + row + row * lda];
- int i = row + 1;
- for (; i < m; i += 1) {
- double a0 = a[offseta + i + row * lda];
- c[offsetc + i + col * ldc] += alphab0 * a0;
- sum0 += b[offsetb + i + col * ldb] * a0;
- }
- if (beta != 0.0) {
- c[offsetc + row + col * ldc] = alpha * sum0 + beta * c[offsetc + row + col * ldc];
- } else {
- c[offsetc + row + col * ldc] = alpha * sum0;
- }
- }
- for (row -= Srow - 1; row >= 0; row -= Srow) {
- double alphab0 = alpha * b[offsetb + (row + 0) + col * ldb];
- double alphab1 = alpha * b[offsetb + (row + 1) + col * ldb];
- double alphab2 = alpha * b[offsetb + (row + 2) + col * ldb];
- double alphab3 = alpha * b[offsetb + (row + 3) + col * ldb];
- double a00 = a[offseta + (row + 0) + (row + 0) * lda];
- double a10 = a[offseta + (row + 1) + (row + 0) * lda];
- double a11 = a[offseta + (row + 1) + (row + 1) * lda];
- double a20 = a[offseta + (row + 2) + (row + 0) * lda];
- double a21 = a[offseta + (row + 2) + (row + 1) * lda];
- double a22 = a[offseta + (row + 2) + (row + 2) * lda];
- double a30 = a[offseta + (row + 3) + (row + 0) * lda];
- double a31 = a[offseta + (row + 3) + (row + 1) * lda];
- double a32 = a[offseta + (row + 3) + (row + 2) * lda];
- double a33 = a[offseta + (row + 3) + (row + 3) * lda];
- double b0 = b[offsetb + (row + 0) + col * ldb];
- double b1 = b[offsetb + (row + 1) + col * ldb];
- double b2 = b[offsetb + (row + 2) + col * ldb];
- double b3 = b[offsetb + (row + 3) + col * ldb];
- double sum0 = 0.0;
- double sum1 = 0.0;
- double sum2 = 0.0;
- double sum3 = 0.0;
- sum0 += b0 * a00 + b1 * a10 + b2 * a20 + b3 * a30;
- sum1 += b0 * a10 + b1 * a11 + b2 * a21 + b3 * a31;
- sum2 += b0 * a20 + b1 * a21 + b2 * a22 + b3 * a32;
- sum3 += b0 * a30 + b1 * a31 + b2 * a32 + b3 * a33;
- int i = row + 4;
- for (; i < m; i += 1) {
- double a0 = a[offseta + i + (row + 0) * lda];
- double a1 = a[offseta + i + (row + 1) * lda];
- double a2 = a[offseta + i + (row + 2) * lda];
- double a3 = a[offseta + i + (row + 3) * lda];
- c[offsetc + i + col * ldc] += alphab0 * a0
- + alphab1 * a1
- + alphab2 * a2
- + alphab3 * a3;
- double bicol = b[offsetb + i + col * ldb];
- sum0 += bicol * a0;
- sum1 += bicol * a1;
- sum2 += bicol * a2;
- sum3 += bicol * a3;
- }
- if (beta != 0.0) {
- c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc];
- c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc];
- c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc];
- c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc];
- } else {
- c[offsetc + (row + 0) + col * ldc] = alpha * sum0;
- c[offsetc + (row + 1) + col * ldc] = alpha * sum1;
- c[offsetc + (row + 2) + col * ldc] = alpha * sum2;
- c[offsetc + (row + 3) + col * ldc] = alpha * sum3;
- }
- }
- }
- }
-
- protected void dsymmRU(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- // C := alpha*B*A + beta*C
- org.netlib.blas.Dsymm.dsymm("R", "U", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected void dsymmRL(int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- // C := alpha*B*A + beta*C
- org.netlib.blas.Dsymm.dsymm("R", "L", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected void ssymmK(String side, String uplo, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- if (alpha == 0.0f) {
- // C := beta*C
- int col = 0;
- for (; col < loopBound(n, 4); col += 4) {
- int row = 0;
- for (; row < m; row += 1) {
- c[offsetc + row + (col + 0) * ldc] = beta * c[offsetc + row + (col + 0) * ldc];
- c[offsetc + row + (col + 1) * ldc] = beta * c[offsetc + row + (col + 1) * ldc];
- c[offsetc + row + (col + 2) * ldc] = beta * c[offsetc + row + (col + 2) * ldc];
- c[offsetc + row + (col + 3) * ldc] = beta * c[offsetc + row + (col + 3) * ldc];
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < m; row += 1) {
- c[offsetc + row + col * ldc] = beta * c[offsetc + row + col * ldc];
- }
- }
- } else if (lsame("L", side) && lsame("U", uplo)) {
- ssymmLU(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else if (lsame("L", side) && lsame("L", uplo)) {
- ssymmLL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else if (lsame("R", side) && lsame("U", uplo)) {
- ssymmRU(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- } else if (lsame("R", side) && lsame("L", uplo)) {
- ssymmRL(m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
- }
-
- protected void ssymmLU(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- // C := alpha*A*B + beta*C
- int col = 0;
- for (; col < loopBound(n, 4); col += 4) {
- int row = 0;
- for (; row < loopBound(m, 4); row += 4) {
- float sum00 = 0.0f;
- float sum10 = 0.0f;
- float sum20 = 0.0f;
- float sum30 = 0.0f;
- float sum01 = 0.0f;
- float sum11 = 0.0f;
- float sum21 = 0.0f;
- float sum31 = 0.0f;
- float sum02 = 0.0f;
- float sum12 = 0.0f;
- float sum22 = 0.0f;
- float sum32 = 0.0f;
- float sum03 = 0.0f;
- float sum13 = 0.0f;
- float sum23 = 0.0f;
- float sum33 = 0.0f;
- float alphab00 = alpha * b[offsetb + (row + 0) + (col + 0) * ldb];
- float alphab10 = alpha * b[offsetb + (row + 1) + (col + 0) * ldb];
- float alphab20 = alpha * b[offsetb + (row + 2) + (col + 0) * ldb];
- float alphab30 = alpha * b[offsetb + (row + 3) + (col + 0) * ldb];
- float alphab01 = alpha * b[offsetb + (row + 0) + (col + 1) * ldb];
- float alphab11 = alpha * b[offsetb + (row + 1) + (col + 1) * ldb];
- float alphab21 = alpha * b[offsetb + (row + 2) + (col + 1) * ldb];
- float alphab31 = alpha * b[offsetb + (row + 3) + (col + 1) * ldb];
- float alphab02 = alpha * b[offsetb + (row + 0) + (col + 2) * ldb];
- float alphab12 = alpha * b[offsetb + (row + 1) + (col + 2) * ldb];
- float alphab22 = alpha * b[offsetb + (row + 2) + (col + 2) * ldb];
- float alphab32 = alpha * b[offsetb + (row + 3) + (col + 2) * ldb];
- float alphab03 = alpha * b[offsetb + (row + 0) + (col + 3) * ldb];
- float alphab13 = alpha * b[offsetb + (row + 1) + (col + 3) * ldb];
- float alphab23 = alpha * b[offsetb + (row + 2) + (col + 3) * ldb];
- float alphab33 = alpha * b[offsetb + (row + 3) + (col + 3) * ldb];
- int i = 0;
- for (; i < row; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float a1 = a[offseta + i + (row + 1) * lda];
- float a2 = a[offseta + i + (row + 2) * lda];
- float a3 = a[offseta + i + (row + 3) * lda];
- c[offsetc + i + (col + 0) * ldc] += alphab00 * a0
- + alphab10 * a1
- + alphab20 * a2
- + alphab30 * a3;
- c[offsetc + i + (col + 1) * ldc] += alphab01 * a0
- + alphab11 * a1
- + alphab21 * a2
- + alphab31 * a3;
- c[offsetc + i + (col + 2) * ldc] += alphab02 * a0
- + alphab12 * a1
- + alphab22 * a2
- + alphab32 * a3;
- c[offsetc + i + (col + 3) * ldc] += alphab03 * a0
- + alphab13 * a1
- + alphab23 * a2
- + alphab33 * a3;
- float b0 = b[offsetb + i + (col + 0) * ldb];
- float b1 = b[offsetb + i + (col + 1) * ldb];
- float b2 = b[offsetb + i + (col + 2) * ldb];
- float b3 = b[offsetb + i + (col + 3) * ldb];
- sum00 += a0 * b0;
- sum10 += a1 * b0;
- sum20 += a2 * b0;
- sum30 += a3 * b0;
- sum01 += a0 * b1;
- sum11 += a1 * b1;
- sum21 += a2 * b1;
- sum31 += a3 * b1;
- sum02 += a0 * b2;
- sum12 += a1 * b2;
- sum22 += a2 * b2;
- sum32 += a3 * b2;
- sum03 += a0 * b3;
- sum13 += a1 * b3;
- sum23 += a2 * b3;
- sum33 += a3 * b3;
- }
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float a01 = a[offseta + (i + 0) + (row + 1) * lda];
- float a02 = a[offseta + (i + 0) + (row + 2) * lda];
- float a03 = a[offseta + (i + 0) + (row + 3) * lda];
- float a11 = a[offseta + (i + 1) + (row + 1) * lda];
- float a12 = a[offseta + (i + 1) + (row + 2) * lda];
- float a13 = a[offseta + (i + 1) + (row + 3) * lda];
- float a22 = a[offseta + (i + 2) + (row + 2) * lda];
- float a23 = a[offseta + (i + 2) + (row + 3) * lda];
- float a33 = a[offseta + (i + 3) + (row + 3) * lda];
- float b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
- float b10 = b[offsetb + (i + 1) + (col + 0) * ldb];
- float b20 = b[offsetb + (i + 2) + (col + 0) * ldb];
- float b30 = b[offsetb + (i + 3) + (col + 0) * ldb];
- float b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
- float b11 = b[offsetb + (i + 1) + (col + 1) * ldb];
- float b21 = b[offsetb + (i + 2) + (col + 1) * ldb];
- float b31 = b[offsetb + (i + 3) + (col + 1) * ldb];
- float b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
- float b12 = b[offsetb + (i + 1) + (col + 2) * ldb];
- float b22 = b[offsetb + (i + 2) + (col + 2) * ldb];
- float b32 = b[offsetb + (i + 3) + (col + 2) * ldb];
- float b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
- float b13 = b[offsetb + (i + 1) + (col + 3) * ldb];
- float b23 = b[offsetb + (i + 2) + (col + 3) * ldb];
- float b33 = b[offsetb + (i + 3) + (col + 3) * ldb];
- sum00 += a00 * b00 + a01 * b10 + a02 * b20 + a03 * b30;
- sum10 += a01 * b00 + a11 * b10 + a12 * b20 + a13 * b30;
- sum20 += a02 * b00 + a12 * b10 + a22 * b20 + a23 * b30;
- sum30 += a03 * b00 + a13 * b10 + a23 * b20 + a33 * b30;
- sum01 += a00 * b01 + a01 * b11 + a02 * b21 + a03 * b31;
- sum11 += a01 * b01 + a11 * b11 + a12 * b21 + a13 * b31;
- sum21 += a02 * b01 + a12 * b11 + a22 * b21 + a23 * b31;
- sum31 += a03 * b01 + a13 * b11 + a23 * b21 + a33 * b31;
- sum02 += a00 * b02 + a01 * b12 + a02 * b22 + a03 * b32;
- sum12 += a01 * b02 + a11 * b12 + a12 * b22 + a13 * b32;
- sum22 += a02 * b02 + a12 * b12 + a22 * b22 + a23 * b32;
- sum32 += a03 * b02 + a13 * b12 + a23 * b22 + a33 * b32;
- sum03 += a00 * b03 + a01 * b13 + a02 * b23 + a03 * b33;
- sum13 += a01 * b03 + a11 * b13 + a12 * b23 + a13 * b33;
- sum23 += a02 * b03 + a12 * b13 + a22 * b23 + a23 * b33;
- sum33 += a03 * b03 + a13 * b13 + a23 * b23 + a33 * b33;
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc];
- c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc];
- c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc];
- c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc];
- c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32;
- c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03;
- c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13;
- c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23;
- c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33;
- }
- }
- for (; row < m; row += 1) {
- float sum0 = 0.0f;
- float sum1 = 0.0f;
- float sum2 = 0.0f;
- float sum3 = 0.0f;
- float alphab0 = alpha * b[offsetb + row + (col + 0) * ldb];
- float alphab1 = alpha * b[offsetb + row + (col + 1) * ldb];
- float alphab2 = alpha * b[offsetb + row + (col + 2) * ldb];
- float alphab3 = alpha * b[offsetb + row + (col + 3) * ldb];
- int i = 0;
- for (; i < row; i += 1) {
- float a0 = a[offseta + i + row * lda];
- c[offsetc + i + (col + 0) * ldc] += alphab0 * a0;
- c[offsetc + i + (col + 1) * ldc] += alphab1 * a0;
- c[offsetc + i + (col + 2) * ldc] += alphab2 * a0;
- c[offsetc + i + (col + 3) * ldc] += alphab3 * a0;
- sum0 += b[offsetb + i + (col + 0) * ldb] * a0;
- sum1 += b[offsetb + i + (col + 1) * ldb] * a0;
- sum2 += b[offsetb + i + (col + 2) * ldb] * a0;
- sum3 += b[offsetb + i + (col + 3) * ldb] * a0;
- }
- float a0 = a[offseta + i + row * lda];
- sum0 += b[offsetb + i + (col + 0) * ldb] * a0;
- sum1 += b[offsetb + i + (col + 1) * ldb] * a0;
- sum2 += b[offsetb + i + (col + 2) * ldb] * a0;
- sum3 += b[offsetb + i + (col + 3) * ldb] * a0;
- if (beta != 0.0f) {
- c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc];
- c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc];
- c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc];
- c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * c[offsetc + row + (col + 3) * ldc];
- } else {
- c[offsetc + row + (col + 0) * ldc] = alpha * sum0;
- c[offsetc + row + (col + 1) * ldc] = alpha * sum1;
- c[offsetc + row + (col + 2) * ldc] = alpha * sum2;
- c[offsetc + row + (col + 3) * ldc] = alpha * sum3;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = 0;
- for (; row < loopBound(m, 4); row += 4) {
- float sum0 = 0.0f;
- float sum1 = 0.0f;
- float sum2 = 0.0f;
- float sum3 = 0.0f;
- float alphab0 = alpha * b[offsetb + (row + 0) + col * ldb];
- float alphab1 = alpha * b[offsetb + (row + 1) + col * ldb];
- float alphab2 = alpha * b[offsetb + (row + 2) + col * ldb];
- float alphab3 = alpha * b[offsetb + (row + 3) + col * ldb];
- int i = 0;
- for (; i < row; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float a1 = a[offseta + i + (row + 1) * lda];
- float a2 = a[offseta + i + (row + 2) * lda];
- float a3 = a[offseta + i + (row + 3) * lda];
- c[offsetc + i + col * ldc] += alphab0 * a0
- + alphab1 * a1
- + alphab2 * a2
- + alphab3 * a3;
- float b0 = b[offsetb + i + col * ldb];
- sum0 += b0 * a0;
- sum1 += b0 * a1;
- sum2 += b0 * a2;
- sum3 += b0 * a3;
- }
- float a00 = a[offseta + (i + 0) + (row + 0) * lda];
- float a01 = a[offseta + (i + 0) + (row + 1) * lda];
- float a02 = a[offseta + (i + 0) + (row + 2) * lda];
- float a03 = a[offseta + (i + 0) + (row + 3) * lda];
- float a11 = a[offseta + (i + 1) + (row + 1) * lda];
- float a12 = a[offseta + (i + 1) + (row + 2) * lda];
- float a13 = a[offseta + (i + 1) + (row + 3) * lda];
- float a22 = a[offseta + (i + 2) + (row + 2) * lda];
- float a23 = a[offseta + (i + 2) + (row + 3) * lda];
- float a33 = a[offseta + (i + 3) + (row + 3) * lda];
- float b0 = b[offsetb + (i + 0) + col * ldb];
- float b1 = b[offsetb + (i + 1) + col * ldb];
- float b2 = b[offsetb + (i + 2) + col * ldb];
- float b3 = b[offsetb + (i + 3) + col * ldb];
- sum0 += b0 * a00 + b1 * a01 + b2 * a02 + b3 * a03;
- sum1 += b0 * a01 + b1 * a11 + b2 * a12 + b3 * a13;
- sum2 += b0 * a02 + b1 * a12 + b2 * a22 + b3 * a23;
- sum3 += b0 * a03 + b1 * a13 + b2 * a23 + b3 * a33;
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc];
- c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc];
- c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc];
- c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc];
- } else {
- c[offsetc + (row + 0) + col * ldc] = alpha * sum0;
- c[offsetc + (row + 1) + col * ldc] = alpha * sum1;
- c[offsetc + (row + 2) + col * ldc] = alpha * sum2;
- c[offsetc + (row + 3) + col * ldc] = alpha * sum3;
- }
- }
- for (; row < m; row += 1) {
- float alphab = alpha * b[offsetb + row + col * ldb];
- float sum = 0.0f;
- int i = 0;
- for (; i < row; i += 1) {
- float aval = a[offseta + i + row * lda];
- c[offsetc + i + col * ldc] += alphab * aval;
- sum += b[offsetb + i + col * ldb] * aval;
- }
- sum += b[offsetb + i + col * ldb] * a[offseta + i + row * lda];
- if (beta != 0.0f) {
- c[offsetc + row + col * ldc] = alpha * sum + beta * c[offsetc + row + col * ldc];
- } else {
- c[offsetc + row + col * ldc] = alpha * sum;
- }
- }
- }
- }
-
- protected void ssymmLL(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- final int Srow = 4;
- // C := alpha*A*B + beta*C
- int col = 0;
- for (; col < loopBound(n, 4); col += 4) {
- int row = m - 1;
- for (; row >= loopBound(m - 1, Srow); row -= 1) {
- float alphab0 = alpha * b[offsetb + row + (col + 0) * ldb];
- float alphab1 = alpha * b[offsetb + row + (col + 1) * ldb];
- float alphab2 = alpha * b[offsetb + row + (col + 2) * ldb];
- float alphab3 = alpha * b[offsetb + row + (col + 3) * ldb];
- float sum0 = 0.0f;
- float sum1 = 0.0f;
- float sum2 = 0.0f;
- float sum3 = 0.0f;
- sum0 += b[offsetb + row + (col + 0) * ldb] * a[offseta + row + row * lda];
- sum1 += b[offsetb + row + (col + 1) * ldb] * a[offseta + row + row * lda];
- sum2 += b[offsetb + row + (col + 2) * ldb] * a[offseta + row + row * lda];
- sum3 += b[offsetb + row + (col + 3) * ldb] * a[offseta + row + row * lda];
- int i = row + 1;
- for (; i < m; i += 1) {
- float airow = a[offseta + i + row * lda];
- c[offsetc + i + (col + 0) * ldc] += alphab0 * airow;
- c[offsetc + i + (col + 1) * ldc] += alphab1 * airow;
- c[offsetc + i + (col + 2) * ldc] += alphab2 * airow;
- c[offsetc + i + (col + 3) * ldc] += alphab3 * airow;
- sum0 += b[offsetb + i + (col + 0) * ldb] * airow;
- sum1 += b[offsetb + i + (col + 1) * ldb] * airow;
- sum2 += b[offsetb + i + (col + 2) * ldb] * airow;
- sum3 += b[offsetb + i + (col + 3) * ldb] * airow;
- }
- if (beta != 0.0f) {
- c[offsetc + row + (col + 0) * ldc] = alpha * sum0 + beta * c[offsetc + row + (col + 0) * ldc];
- c[offsetc + row + (col + 1) * ldc] = alpha * sum1 + beta * c[offsetc + row + (col + 1) * ldc];
- c[offsetc + row + (col + 2) * ldc] = alpha * sum2 + beta * c[offsetc + row + (col + 2) * ldc];
- c[offsetc + row + (col + 3) * ldc] = alpha * sum3 + beta * c[offsetc + row + (col + 3) * ldc];
- } else {
- c[offsetc + row + (col + 0) * ldc] = alpha * sum0;
- c[offsetc + row + (col + 1) * ldc] = alpha * sum1;
- c[offsetc + row + (col + 2) * ldc] = alpha * sum2;
- c[offsetc + row + (col + 3) * ldc] = alpha * sum3;
- }
- }
- for (row -= Srow - 1; row >= 0; row -= Srow) {
- float a00 = a[offseta + (row + 0) + (row + 0) * lda];
- float a10 = a[offseta + (row + 1) + (row + 0) * lda];
- float a11 = a[offseta + (row + 1) + (row + 1) * lda];
- float a20 = a[offseta + (row + 2) + (row + 0) * lda];
- float a21 = a[offseta + (row + 2) + (row + 1) * lda];
- float a22 = a[offseta + (row + 2) + (row + 2) * lda];
- float a30 = a[offseta + (row + 3) + (row + 0) * lda];
- float a31 = a[offseta + (row + 3) + (row + 1) * lda];
- float a32 = a[offseta + (row + 3) + (row + 2) * lda];
- float a33 = a[offseta + (row + 3) + (row + 3) * lda];
- float b00 = b[offsetb + (row + 0) + (col + 0) * ldb];
- float b10 = b[offsetb + (row + 1) + (col + 0) * ldb];
- float b20 = b[offsetb + (row + 2) + (col + 0) * ldb];
- float b30 = b[offsetb + (row + 3) + (col + 0) * ldb];
- float b01 = b[offsetb + (row + 0) + (col + 1) * ldb];
- float b11 = b[offsetb + (row + 1) + (col + 1) * ldb];
- float b21 = b[offsetb + (row + 2) + (col + 1) * ldb];
- float b31 = b[offsetb + (row + 3) + (col + 1) * ldb];
- float b02 = b[offsetb + (row + 0) + (col + 2) * ldb];
- float b12 = b[offsetb + (row + 1) + (col + 2) * ldb];
- float b22 = b[offsetb + (row + 2) + (col + 2) * ldb];
- float b32 = b[offsetb + (row + 3) + (col + 2) * ldb];
- float b03 = b[offsetb + (row + 0) + (col + 3) * ldb];
- float b13 = b[offsetb + (row + 1) + (col + 3) * ldb];
- float b23 = b[offsetb + (row + 2) + (col + 3) * ldb];
- float b33 = b[offsetb + (row + 3) + (col + 3) * ldb];
- float alphab00 = alpha * b00;
- float alphab10 = alpha * b10;
- float alphab20 = alpha * b20;
- float alphab30 = alpha * b30;
- float alphab01 = alpha * b01;
- float alphab11 = alpha * b11;
- float alphab21 = alpha * b21;
- float alphab31 = alpha * b31;
- float alphab02 = alpha * b02;
- float alphab12 = alpha * b12;
- float alphab22 = alpha * b22;
- float alphab32 = alpha * b32;
- float alphab03 = alpha * b03;
- float alphab13 = alpha * b13;
- float alphab23 = alpha * b23;
- float alphab33 = alpha * b33;
- float sum00 = 0.0f;
- float sum10 = 0.0f;
- float sum20 = 0.0f;
- float sum30 = 0.0f;
- float sum01 = 0.0f;
- float sum11 = 0.0f;
- float sum21 = 0.0f;
- float sum31 = 0.0f;
- float sum02 = 0.0f;
- float sum12 = 0.0f;
- float sum22 = 0.0f;
- float sum32 = 0.0f;
- float sum03 = 0.0f;
- float sum13 = 0.0f;
- float sum23 = 0.0f;
- float sum33 = 0.0f;
- sum00 += b00 * a00 + b10 * a10 + b20 * a20 + b30 * a30;
- sum10 += b00 * a10 + b10 * a11 + b20 * a21 + b30 * a31;
- sum20 += b00 * a20 + b10 * a21 + b20 * a22 + b30 * a32;
- sum30 += b00 * a30 + b10 * a31 + b20 * a32 + b30 * a33;
- sum01 += b01 * a00 + b11 * a10 + b21 * a20 + b31 * a30;
- sum11 += b01 * a10 + b11 * a11 + b21 * a21 + b31 * a31;
- sum21 += b01 * a20 + b11 * a21 + b21 * a22 + b31 * a32;
- sum31 += b01 * a30 + b11 * a31 + b21 * a32 + b31 * a33;
- sum02 += b02 * a00 + b12 * a10 + b22 * a20 + b32 * a30;
- sum12 += b02 * a10 + b12 * a11 + b22 * a21 + b32 * a31;
- sum22 += b02 * a20 + b12 * a21 + b22 * a22 + b32 * a32;
- sum32 += b02 * a30 + b12 * a31 + b22 * a32 + b32 * a33;
- sum03 += b03 * a00 + b13 * a10 + b23 * a20 + b33 * a30;
- sum13 += b03 * a10 + b13 * a11 + b23 * a21 + b33 * a31;
- sum23 += b03 * a20 + b13 * a21 + b23 * a22 + b33 * a32;
- sum33 += b03 * a30 + b13 * a31 + b23 * a32 + b33 * a33;
- int i = row + 4;
- for (; i < m; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float a1 = a[offseta + i + (row + 1) * lda];
- float a2 = a[offseta + i + (row + 2) * lda];
- float a3 = a[offseta + i + (row + 3) * lda];
- c[offsetc + i + (col + 0) * ldc] += alphab00 * a0
- + alphab10 * a1
- + alphab20 * a2
- + alphab30 * a3;
- c[offsetc + i + (col + 1) * ldc] += alphab01 * a0
- + alphab11 * a1
- + alphab21 * a2
- + alphab31 * a3;
- c[offsetc + i + (col + 2) * ldc] += alphab02 * a0
- + alphab12 * a1
- + alphab22 * a2
- + alphab32 * a3;
- c[offsetc + i + (col + 3) * ldc] += alphab03 * a0
- + alphab13 * a1
- + alphab23 * a2
- + alphab33 * a3;
- float b0 = b[offsetb + i + (col + 0) * ldb];
- float b1 = b[offsetb + i + (col + 1) * ldb];
- float b2 = b[offsetb + i + (col + 2) * ldb];
- float b3 = b[offsetb + i + (col + 3) * ldb];
- sum00 += b0 * a0;
- sum10 += b0 * a1;
- sum20 += b0 * a2;
- sum30 += b0 * a3;
- sum01 += b1 * a0;
- sum11 += b1 * a1;
- sum21 += b1 * a2;
- sum31 += b1 * a3;
- sum02 += b2 * a0;
- sum12 += b2 * a1;
- sum22 += b2 * a2;
- sum32 += b2 * a3;
- sum03 += b3 * a0;
- sum13 += b3 * a1;
- sum23 += b3 * a2;
- sum33 += b3 * a3;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00 + beta * c[offsetc + (row + 0) + (col + 0) * ldc];
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10 + beta * c[offsetc + (row + 1) + (col + 0) * ldc];
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20 + beta * c[offsetc + (row + 2) + (col + 0) * ldc];
- c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30 + beta * c[offsetc + (row + 3) + (col + 0) * ldc];
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01 + beta * c[offsetc + (row + 0) + (col + 1) * ldc];
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11 + beta * c[offsetc + (row + 1) + (col + 1) * ldc];
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21 + beta * c[offsetc + (row + 2) + (col + 1) * ldc];
- c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31 + beta * c[offsetc + (row + 3) + (col + 1) * ldc];
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02 + beta * c[offsetc + (row + 0) + (col + 2) * ldc];
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12 + beta * c[offsetc + (row + 1) + (col + 2) * ldc];
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22 + beta * c[offsetc + (row + 2) + (col + 2) * ldc];
- c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32 + beta * c[offsetc + (row + 3) + (col + 2) * ldc];
- c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03 + beta * c[offsetc + (row + 0) + (col + 3) * ldc];
- c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13 + beta * c[offsetc + (row + 1) + (col + 3) * ldc];
- c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23 + beta * c[offsetc + (row + 2) + (col + 3) * ldc];
- c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33 + beta * c[offsetc + (row + 3) + (col + 3) * ldc];
- } else {
- c[offsetc + (row + 0) + (col + 0) * ldc] = alpha * sum00;
- c[offsetc + (row + 1) + (col + 0) * ldc] = alpha * sum10;
- c[offsetc + (row + 2) + (col + 0) * ldc] = alpha * sum20;
- c[offsetc + (row + 3) + (col + 0) * ldc] = alpha * sum30;
- c[offsetc + (row + 0) + (col + 1) * ldc] = alpha * sum01;
- c[offsetc + (row + 1) + (col + 1) * ldc] = alpha * sum11;
- c[offsetc + (row + 2) + (col + 1) * ldc] = alpha * sum21;
- c[offsetc + (row + 3) + (col + 1) * ldc] = alpha * sum31;
- c[offsetc + (row + 0) + (col + 2) * ldc] = alpha * sum02;
- c[offsetc + (row + 1) + (col + 2) * ldc] = alpha * sum12;
- c[offsetc + (row + 2) + (col + 2) * ldc] = alpha * sum22;
- c[offsetc + (row + 3) + (col + 2) * ldc] = alpha * sum32;
- c[offsetc + (row + 0) + (col + 3) * ldc] = alpha * sum03;
- c[offsetc + (row + 1) + (col + 3) * ldc] = alpha * sum13;
- c[offsetc + (row + 2) + (col + 3) * ldc] = alpha * sum23;
- c[offsetc + (row + 3) + (col + 3) * ldc] = alpha * sum33;
- }
- }
- }
- for (; col < n; col += 1) {
- int row = m - 1;
- for (; row >= loopBound(m - 1, Srow); row -= 1) {
- float alphab0 = alpha * b[offsetb + row + col * ldb];
- float sum0 = 0.0f;
- sum0 += b[offsetb + row + col * ldb] * a[offseta + row + row * lda];
- int i = row + 1;
- for (; i < m; i += 1) {
- float a0 = a[offseta + i + row * lda];
- c[offsetc + i + col * ldc] += alphab0 * a0;
- sum0 += b[offsetb + i + col * ldb] * a0;
- }
- if (beta != 0.0f) {
- c[offsetc + row + col * ldc] = alpha * sum0 + beta * c[offsetc + row + col * ldc];
- } else {
- c[offsetc + row + col * ldc] = alpha * sum0;
- }
- }
- for (row -= Srow - 1; row >= 0; row -= Srow) {
- float alphab0 = alpha * b[offsetb + (row + 0) + col * ldb];
- float alphab1 = alpha * b[offsetb + (row + 1) + col * ldb];
- float alphab2 = alpha * b[offsetb + (row + 2) + col * ldb];
- float alphab3 = alpha * b[offsetb + (row + 3) + col * ldb];
- float a00 = a[offseta + (row + 0) + (row + 0) * lda];
- float a10 = a[offseta + (row + 1) + (row + 0) * lda];
- float a11 = a[offseta + (row + 1) + (row + 1) * lda];
- float a20 = a[offseta + (row + 2) + (row + 0) * lda];
- float a21 = a[offseta + (row + 2) + (row + 1) * lda];
- float a22 = a[offseta + (row + 2) + (row + 2) * lda];
- float a30 = a[offseta + (row + 3) + (row + 0) * lda];
- float a31 = a[offseta + (row + 3) + (row + 1) * lda];
- float a32 = a[offseta + (row + 3) + (row + 2) * lda];
- float a33 = a[offseta + (row + 3) + (row + 3) * lda];
- float b0 = b[offsetb + (row + 0) + col * ldb];
- float b1 = b[offsetb + (row + 1) + col * ldb];
- float b2 = b[offsetb + (row + 2) + col * ldb];
- float b3 = b[offsetb + (row + 3) + col * ldb];
- float sum0 = 0.0f;
- float sum1 = 0.0f;
- float sum2 = 0.0f;
- float sum3 = 0.0f;
- sum0 += b0 * a00 + b1 * a10 + b2 * a20 + b3 * a30;
- sum1 += b0 * a10 + b1 * a11 + b2 * a21 + b3 * a31;
- sum2 += b0 * a20 + b1 * a21 + b2 * a22 + b3 * a32;
- sum3 += b0 * a30 + b1 * a31 + b2 * a32 + b3 * a33;
- int i = row + 4;
- for (; i < m; i += 1) {
- float a0 = a[offseta + i + (row + 0) * lda];
- float a1 = a[offseta + i + (row + 1) * lda];
- float a2 = a[offseta + i + (row + 2) * lda];
- float a3 = a[offseta + i + (row + 3) * lda];
- c[offsetc + i + col * ldc] += alphab0 * a0
- + alphab1 * a1
- + alphab2 * a2
- + alphab3 * a3;
- float bicol = b[offsetb + i + col * ldb];
- sum0 += bicol * a0;
- sum1 += bicol * a1;
- sum2 += bicol * a2;
- sum3 += bicol * a3;
- }
- if (beta != 0.0f) {
- c[offsetc + (row + 0) + col * ldc] = alpha * sum0 + beta * c[offsetc + (row + 0) + col * ldc];
- c[offsetc + (row + 1) + col * ldc] = alpha * sum1 + beta * c[offsetc + (row + 1) + col * ldc];
- c[offsetc + (row + 2) + col * ldc] = alpha * sum2 + beta * c[offsetc + (row + 2) + col * ldc];
- c[offsetc + (row + 3) + col * ldc] = alpha * sum3 + beta * c[offsetc + (row + 3) + col * ldc];
- } else {
- c[offsetc + (row + 0) + col * ldc] = alpha * sum0;
- c[offsetc + (row + 1) + col * ldc] = alpha * sum1;
- c[offsetc + (row + 2) + col * ldc] = alpha * sum2;
- c[offsetc + (row + 3) + col * ldc] = alpha * sum3;
- }
- }
- }
- }
-
- protected void ssymmRU(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- // C := alpha*B*A + beta*C
- org.netlib.blas.Ssymm.ssymm("R", "U", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected void ssymmRL(int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- // C := alpha*B*A + beta*C
- org.netlib.blas.Ssymm.ssymm("R", "L", m, n, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected void dsymvK(String uplo, int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- if (alpha == 0.0) {
- for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
- if (beta != 0.0) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0;
- }
- }
- } else if (lsame("U", uplo)) {
- dsymvU(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- } else if (lsame("L", uplo)) {
- dsymvL(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- }
-
- protected void dsymvU(int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
- double alphaxix0 = alpha * x[offsetx + ix + incx * 0];
- double alphaxix1 = alpha * x[offsetx + ix + incx * 1];
- double alphaxix2 = alpha * x[offsetx + ix + incx * 2];
- double alphaxix3 = alpha * x[offsetx + ix + incx * 3];
- double sumiy0 = 0.0;
- double sumiy1 = 0.0;
- double sumiy2 = 0.0;
- double sumiy3 = 0.0;
- int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
- for (; row < col; row += 1, jx += incx, jy += incy) {
- double a0 = a[offseta + row + (col + 0) * lda];
- double a1 = a[offseta + row + (col + 1) * lda];
- double a2 = a[offseta + row + (col + 2) * lda];
- double a3 = a[offseta + row + (col + 3) * lda];
- y[offsety + jy] += alphaxix0 * a0 + alphaxix1 * a1 + alphaxix2 * a2 + alphaxix3 * a3;
- double x0 = x[offsetx + jx];
- sumiy0 += x0 * a0;
- sumiy1 += x0 * a1;
- sumiy2 += x0 * a2;
- sumiy3 += x0 * a3;
- }
- double a00 = a[offseta + (row + 0) + (col + 0) * lda];
- double a01 = a[offseta + (row + 0) + (col + 1) * lda];
- double a02 = a[offseta + (row + 0) + (col + 2) * lda];
- double a03 = a[offseta + (row + 0) + (col + 3) * lda];
- double a11 = a[offseta + (row + 1) + (col + 1) * lda];
- double a12 = a[offseta + (row + 1) + (col + 2) * lda];
- double a13 = a[offseta + (row + 1) + (col + 3) * lda];
- double a22 = a[offseta + (row + 2) + (col + 2) * lda];
- double a23 = a[offseta + (row + 2) + (col + 3) * lda];
- double a33 = a[offseta + (row + 3) + (col + 3) * lda];
- double xjx0 = x[offsetx + jx + incx * 0];
- double xjx1 = x[offsetx + jx + incx * 1];
- double xjx2 = x[offsetx + jx + incx * 2];
- double xjx3 = x[offsetx + jx + incx * 3];
- sumiy0 += xjx0 * a00
- + xjx1 * a01
- + xjx2 * a02
- + xjx3 * a03;
- sumiy1 += xjx0 * a01
- + xjx1 * a11
- + xjx2 * a12
- + xjx3 * a13;
- sumiy2 += xjx0 * a02
- + xjx1 * a12
- + xjx2 * a22
- + xjx3 * a23;
- sumiy3 += xjx0 * a03
- + xjx1 * a13
- + xjx2 * a23
- + xjx3 * a33;
- if (beta != 0.0) {
- y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0];
- y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1];
- y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy + incy * 2];
- y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3];
- } else {
- y[offsety + iy + incy * 0] = alpha * sumiy0;
- y[offsety + iy + incy * 1] = alpha * sumiy1;
- y[offsety + iy + incy * 2] = alpha * sumiy2;
- y[offsety + iy + incy * 3] = alpha * sumiy3;
- }
- }
- for (; col < n; col += 1, ix += incx, iy += incy) {
- double alphaxix = alpha * x[offsetx + ix];
- double sumiy = 0.0;
- int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
- for (; row < col; row += 1, jx += incx, jy += incy) {
- double a0 = a[offseta + row + col * lda];
- y[offsety + jy] += alphaxix * a0;
- sumiy += x[offsetx + jx] * a0;
- }
- sumiy += x[offsetx + jx] * a[offseta + row + col * lda];
- if (beta != 0.0) {
- y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy];
- } else {
- y[offsety + iy] = alpha * sumiy;
- }
- }
- }
-
- protected void dsymvL(int n, double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, double beta, double[] y, int offsety, int incy) {
- // y = beta * y
- if (beta != 1.0) {
- for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
- if (beta != 0.0) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0;
- }
- }
- }
- // y += alpha * A * x
- int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
- double alphaxix0 = alpha * x[offsetx + ix + incx * 0];
- double alphaxix1 = alpha * x[offsetx + ix + incx * 1];
- double alphaxix2 = alpha * x[offsetx + ix + incx * 2];
- double alphaxix3 = alpha * x[offsetx + ix + incx * 3];
- double sumiy0 = 0.0;
- double sumiy1 = 0.0;
- double sumiy2 = 0.0;
- double sumiy3 = 0.0;
- double a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * lda];
- double a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * lda];
- double a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * lda];
- double a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * lda];
- double a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * lda];
- double a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * lda];
- double a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * lda];
- double a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * lda];
- double a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * lda];
- double a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * lda];
- double x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)];
- double x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)];
- double x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)];
- double x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)];
- sumiy0 += x0 * a00
- + x1 * a10
- + x2 * a20
- + x3 * a30;
- sumiy1 += x0 * a10
- + x1 * a11
- + x2 * a21
- + x3 * a31;
- sumiy2 += x0 * a20
- + x1 * a21
- + x2 * a22
- + x3 * a32;
- sumiy3 += x0 * a30
- + x1 * a31
- + x2 * a32
- + x3 * a33;
- int row = col + 4, jx = incx < 0 ? (n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? (n - (col + 4) - 1) * -incy : (col + 4) * incy;
- for (; row < n; row += 1, jx += incx, jy += incy) {
- double a0 = a[offseta + row + (col + 0) * lda];
- double a1 = a[offseta + row + (col + 1) * lda];
- double a2 = a[offseta + row + (col + 2) * lda];
- double a3 = a[offseta + row + (col + 3) * lda];
- y[offsety + jy] += alphaxix0 * a0
- + alphaxix1 * a1
- + alphaxix2 * a2
- + alphaxix3 * a3;
- double xjx = x[offsetx + jx];
- sumiy0 += xjx * a0;
- sumiy1 += xjx * a1;
- sumiy2 += xjx * a2;
- sumiy3 += xjx * a3;
- }
- y[offsety + iy + incy * 0] += alpha * sumiy0;
- y[offsety + iy + incy * 1] += alpha * sumiy1;
- y[offsety + iy + incy * 2] += alpha * sumiy2;
- y[offsety + iy + incy * 3] += alpha * sumiy3;
- }
- for (; col < n; col += 1, ix += incx, iy += incy) {
- double alphaxix = alpha * x[offsetx + ix];
- double sumiy = 0.0;
- sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * lda];
- int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? (n - (col + 1) - 1) * -incy : (col + 1) * incy;
- for (; row < n; row += 1, jx += incx, jy += incy) {
- y[offsety + jy] += alphaxix * a[offseta + row + col * lda];
- sumiy += x[offsetx + jx] * a[offseta + row + col * lda];
- }
- y[offsety + iy] += alpha * sumiy;
- }
- }
-
- protected void ssymvK(String uplo, int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- if (alpha == 0.0f) {
- for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
- if (beta != 0.0f) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0f;
- }
- }
- } else if (lsame("U", uplo)) {
- ssymvU(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- } else if (lsame("L", uplo)) {
- ssymvL(n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
- }
- }
-
- protected void ssymvU(int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
- float alphaxix0 = alpha * x[offsetx + ix + incx * 0];
- float alphaxix1 = alpha * x[offsetx + ix + incx * 1];
- float alphaxix2 = alpha * x[offsetx + ix + incx * 2];
- float alphaxix3 = alpha * x[offsetx + ix + incx * 3];
- float sumiy0 = 0.0f;
- float sumiy1 = 0.0f;
- float sumiy2 = 0.0f;
- float sumiy3 = 0.0f;
- int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
- for (; row < col; row += 1, jx += incx, jy += incy) {
- y[offsety + jy] += alphaxix0 * a[offseta + row + (col + 0) * lda]
- + alphaxix1 * a[offseta + row + (col + 1) * lda]
- + alphaxix2 * a[offseta + row + (col + 2) * lda]
- + alphaxix3 * a[offseta + row + (col + 3) * lda];
- float xjx = x[offsetx + jx];
- sumiy0 += xjx * a[offseta + row + (col + 0) * lda];
- sumiy1 += xjx * a[offseta + row + (col + 1) * lda];
- sumiy2 += xjx * a[offseta + row + (col + 2) * lda];
- sumiy3 += xjx * a[offseta + row + (col + 3) * lda];
- }
- float a00 = a[offseta + (row + 0) + (col + 0) * lda];
- float a01 = a[offseta + (row + 0) + (col + 1) * lda];
- float a02 = a[offseta + (row + 0) + (col + 2) * lda];
- float a03 = a[offseta + (row + 0) + (col + 3) * lda];
- float a11 = a[offseta + (row + 1) + (col + 1) * lda];
- float a12 = a[offseta + (row + 1) + (col + 2) * lda];
- float a13 = a[offseta + (row + 1) + (col + 3) * lda];
- float a22 = a[offseta + (row + 2) + (col + 2) * lda];
- float a23 = a[offseta + (row + 2) + (col + 3) * lda];
- float a33 = a[offseta + (row + 3) + (col + 3) * lda];
- float xjx0 = x[offsetx + jx + incx * 0];
- float xjx1 = x[offsetx + jx + incx * 1];
- float xjx2 = x[offsetx + jx + incx * 2];
- float xjx3 = x[offsetx + jx + incx * 3];
- sumiy0 += xjx0 * a00
- + xjx1 * a01
- + xjx2 * a02
- + xjx3 * a03;
- sumiy1 += xjx0 * a01
- + xjx1 * a11
- + xjx2 * a12
- + xjx3 * a13;
- sumiy2 += xjx0 * a02
- + xjx1 * a12
- + xjx2 * a22
- + xjx3 * a23;
- sumiy3 += xjx0 * a03
- + xjx1 * a13
- + xjx2 * a23
- + xjx3 * a33;
- if (beta != 0.0f) {
- y[offsety + iy + incy * 0] = alpha * sumiy0 + beta * y[offsety + iy + incy * 0];
- y[offsety + iy + incy * 1] = alpha * sumiy1 + beta * y[offsety + iy + incy * 1];
- y[offsety + iy + incy * 2] = alpha * sumiy2 + beta * y[offsety + iy + incy * 2];
- y[offsety + iy + incy * 3] = alpha * sumiy3 + beta * y[offsety + iy + incy * 3];
- } else {
- y[offsety + iy + incy * 0] = alpha * sumiy0;
- y[offsety + iy + incy * 1] = alpha * sumiy1;
- y[offsety + iy + incy * 2] = alpha * sumiy2;
- y[offsety + iy + incy * 3] = alpha * sumiy3;
- }
- }
- for (; col < n; col += 1, ix += incx, iy += incy) {
- float alphaxix = alpha * x[offsetx + ix];
- float sumiy = 0.0f;
- int row = 0, jx = incx < 0 ? (col - 1) * -incx : 0, jy = incy < 0 ? (col - 1) * -incy : 0;
- for (; row < col; row += 1, jx += incx, jy += incy) {
- y[offsety + jy] += alphaxix * a[offseta + row + col * lda];
- sumiy += x[offsetx + jx] * a[offseta + row + col * lda];
- }
- sumiy += x[offsetx + jx] * a[offseta + row + col * lda];
- if (beta != 0.0f) {
- y[offsety + iy] = alpha * sumiy + beta * y[offsety + iy];
- } else {
- y[offsety + iy] = alpha * sumiy;
- }
- }
- }
-
- protected void ssymvL(int n, float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, float beta, float[] y, int offsety, int incy) {
- // y = beta * y
- if (beta != 1.0f) {
- for (int i = 0, iy = incy < 0 ? (n - 1) * -incy : 0; i < n; i += 1, iy += incy) {
- if (beta != 0.0f) {
- y[offsety + iy] = beta * y[offsety + iy];
- } else {
- y[offsety + iy] = 0.0f;
- }
- }
- }
- // y += alpha * A * x
- int col = 0, ix = incx < 0 ? (n - 1) * -incx : 0, iy = incy < 0 ? (n - 1) * -incy : 0;
- for (; col < loopBound(n, 4); col += 4, ix += incx * 4, iy += incy * 4) {
- float alphaxix0 = alpha * x[offsetx + ix + incx * 0];
- float alphaxix1 = alpha * x[offsetx + ix + incx * 1];
- float alphaxix2 = alpha * x[offsetx + ix + incx * 2];
- float alphaxix3 = alpha * x[offsetx + ix + incx * 3];
- float sumiy0 = 0.0f;
- float sumiy1 = 0.0f;
- float sumiy2 = 0.0f;
- float sumiy3 = 0.0f;
- float a00 = a[offseta + /*row=*/(col + 0) + (col + 0) * lda];
- float a10 = a[offseta + /*row=*/(col + 1) + (col + 0) * lda];
- float a11 = a[offseta + /*row=*/(col + 1) + (col + 1) * lda];
- float a20 = a[offseta + /*row=*/(col + 2) + (col + 0) * lda];
- float a21 = a[offseta + /*row=*/(col + 2) + (col + 1) * lda];
- float a22 = a[offseta + /*row=*/(col + 2) + (col + 2) * lda];
- float a30 = a[offseta + /*row=*/(col + 3) + (col + 0) * lda];
- float a31 = a[offseta + /*row=*/(col + 3) + (col + 1) * lda];
- float a32 = a[offseta + /*row=*/(col + 3) + (col + 2) * lda];
- float a33 = a[offseta + /*row=*/(col + 3) + (col + 3) * lda];
- float x0 = x[offsetx + (incx < 0 ? (n - (col + 0) - 1) * -incx : (col + 0) * incx)];
- float x1 = x[offsetx + (incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx)];
- float x2 = x[offsetx + (incx < 0 ? (n - (col + 2) - 1) * -incx : (col + 2) * incx)];
- float x3 = x[offsetx + (incx < 0 ? (n - (col + 3) - 1) * -incx : (col + 3) * incx)];
- sumiy0 += x0 * a00
- + x1 * a10
- + x2 * a20
- + x3 * a30;
- sumiy1 += x0 * a10
- + x1 * a11
- + x2 * a21
- + x3 * a31;
- sumiy2 += x0 * a20
- + x1 * a21
- + x2 * a22
- + x3 * a32;
- sumiy3 += x0 * a30
- + x1 * a31
- + x2 * a32
- + x3 * a33;
- int row = col + 4, jx = incx < 0 ? (n - (col + 4) - 1) * -incx : (col + 4) * incx, jy = incy < 0 ? (n - (col + 4) - 1) * -incy : (col + 4) * incy;
- for (; row < n; row += 1, jx += incx, jy += incy) {
- float a0 = a[offseta + row + (col + 0) * lda];
- float a1 = a[offseta + row + (col + 1) * lda];
- float a2 = a[offseta + row + (col + 2) * lda];
- float a3 = a[offseta + row + (col + 3) * lda];
- y[offsety + jy] += alphaxix0 * a0
- + alphaxix1 * a1
- + alphaxix2 * a2
- + alphaxix3 * a3;
- float xjx = x[offsetx + jx];
- sumiy0 += xjx * a0;
- sumiy1 += xjx * a1;
- sumiy2 += xjx * a2;
- sumiy3 += xjx * a3;
- }
- y[offsety + iy + incy * 0] += alpha * sumiy0;
- y[offsety + iy + incy * 1] += alpha * sumiy1;
- y[offsety + iy + incy * 2] += alpha * sumiy2;
- y[offsety + iy + incy * 3] += alpha * sumiy3;
- }
- for (; col < n; col += 1, ix += incx, iy += incy) {
- float alphaxix = alpha * x[offsetx + ix];
- float sumiy = 0.0f;
- sumiy += x[offsetx + (incx < 0 ? (n - col - 1) * -incx : col * incx)] * a[offseta + /*row=*/col + col * lda];
- int row = col + 1, jx = incx < 0 ? (n - (col + 1) - 1) * -incx : (col + 1) * incx, jy = incy < 0 ? (n - (col + 1) - 1) * -incy : (col + 1) * incy;
- for (; row < n; row += 1, jx += incx, jy += incy) {
- y[offsety + jy] += alphaxix * a[offseta + row + col * lda];
- sumiy += x[offsetx + jx] * a[offseta + row + col * lda];
- }
- y[offsety + iy] += alpha * sumiy;
- }
- }
-
- protected void dsyrK(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] a, int offseta, int lda) {
- org.netlib.blas.Dsyr.dsyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
- }
-
- protected void ssyrK(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] a, int offseta, int lda) {
- org.netlib.blas.Ssyr.ssyr(uplo, n, alpha, x, offsetx, incx, a, offseta, lda);
- }
-
- protected void dsyr2K(String uplo, int n, double alpha, double[] x, int offsetx, int incx, double[] y, int offsety, int incy, double[] a, int offseta, int lda) {
- org.netlib.blas.Dsyr2.dsyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
- }
-
- protected void ssyr2K(String uplo, int n, float alpha, float[] x, int offsetx, int incx, float[] y, int offsety, int incy, float[] a, int offseta, int lda) {
- org.netlib.blas.Ssyr2.ssyr2(uplo, n, alpha, x, offsetx, incx, y, offsety, incy, a, offseta, lda);
- }
-
- protected void dsyr2kK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
- org.netlib.blas.Dsyr2k.dsyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected void ssyr2kK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb, float beta, float[] c, int offsetc, int ldc) {
- org.netlib.blas.Ssyr2k.ssyr2k(uplo, trans, n, k, alpha, a, offseta, lda, b, offsetb, ldb, beta, c, offsetc, ldc);
- }
-
- protected void dsyrkK(String uplo, String trans, int n, int k, double alpha, double[] a, int offseta, int lda, double beta, double[] c, int offsetc, int ldc) {
- org.netlib.blas.Dsyrk.dsyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
- }
-
- protected void ssyrkK(String uplo, String trans, int n, int k, float alpha, float[] a, int offseta, int lda, float beta, float[] c, int offsetc, int ldc) {
- org.netlib.blas.Ssyrk.ssyrk(uplo, trans, n, k, alpha, a, offseta, lda, beta, c, offsetc, ldc);
- }
-
- protected void dtbmvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtbmv.dtbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
-
- protected void stbmvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- org.netlib.blas.Stbmv.stbmv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
-
- protected void dtbsvK(String uplo, String trans, String diag, int n, int k, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtbsv.dtbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
-
- protected void stbsvK(String uplo, String trans, String diag, int n, int k, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- org.netlib.blas.Stbsv.stbsv(uplo, trans, diag, n, k, a, offseta, lda, x, offsetx, incx);
- }
-
- protected void dtpmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtpmv.dtpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
-
- protected void stpmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
- org.netlib.blas.Stpmv.stpmv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
-
- protected void dtpsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtpsv.dtpsv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
-
- protected void stpsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, float[] x, int offsetx, int incx) {
- org.netlib.blas.Stpsv.stpsv(uplo, trans, diag, n, a, offseta, x, offsetx, incx);
- }
-
- protected void dtrmmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
- org.netlib.blas.Dtrmm.dtrmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
-
- protected void strmmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
- org.netlib.blas.Strmm.strmm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
-
- protected void dtrmvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtrmv.dtrmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
-
- protected void strmvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- org.netlib.blas.Strmv.strmv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
-
- protected void dtrsmK(String side, String uplo, String transa, String diag, int m, int n, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb) {
- org.netlib.blas.Dtrsm.dtrsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
-
- protected void strsmK(String side, String uplo, String transa, String diag, int m, int n, float alpha, float[] a, int offseta, int lda, float[] b, int offsetb, int ldb) {
- org.netlib.blas.Strsm.strsm(side, uplo, transa, diag, m, n, alpha, a, offseta, lda, b, offsetb, ldb);
- }
-
- protected void dtrsvK(String uplo, String trans, String diag, int n, double[] a, int offseta, int lda, double[] x, int offsetx, int incx) {
- org.netlib.blas.Dtrsv.dtrsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
-
- protected void strsvK(String uplo, String trans, String diag, int n, float[] a, int offseta, int lda, float[] x, int offsetx, int incx) {
- org.netlib.blas.Strsv.strsv(uplo, trans, diag, n, a, offseta, lda, x, offsetx, incx);
- }
-
- protected int idamaxK(int n, double[] x, int offsetx, int incx) {
- return org.netlib.blas.Idamax.idamax(n, x, offsetx, incx);
- }
-
- protected int isamaxK(int n, float[] x, int offsetx, int incx) {
- return org.netlib.blas.Isamax.isamax(n, x, offsetx, incx);
- }
-}
diff --git a/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/CRFModel.scala b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/CRFModel.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b76d254f73c58c21ef814f13caa135e2ef8258d3
--- /dev/null
+++ b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/CRFModel.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// scalastyle:off
+package com.intel.ssg.bdt.nlp
+
+import java.io._
+import java.nio.file.{Files, Paths, StandardOpenOption}
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.rdd.RDD
+
+trait VerboseMode
+
+case object VerboseLevel1 extends VerboseMode
+
+case object VerboseLevel2 extends VerboseMode
+
+case class CRFModel (
+ head: Array[String],
+ dic: Array[(String, Int)],
+ alpha: Array[Double]) extends Serializable {
+
+ protected def formatVersion = "1.0"
+
+ private var verboseMode: Option[VerboseMode] = None
+
+ private var nBest = 0
+ private var costFactor = 1.0
+
+ def setNBest(nBest: Int): CRFModel = {
+ this.nBest = nBest
+ this
+ }
+
+ def setVerboseMode(mode: VerboseMode): CRFModel = {
+ this.verboseMode = Some(mode)
+ this
+ }
+
+ def setcostFact(cf: Double): CRFModel = {
+ this.costFactor = cf
+ this
+ }
+
+ override def toString: String = {
+ val dicString = dic.map{case(k, v) => k + "|-|" + v.toString}
+ s"${head.mkString("\t")}|--|${dicString.mkString("\t")}|--|${alpha.map(_.toFloat).mkString("\t")}"
+ }
+
+ def toStringHead: String = {
+ val dicString: Array[String] = dic.map{case(k, v) => k + "|-|" + v.toString}
+ s"${head.mkString("\t")}|--|${dicString.mkString("\t")}"
+ }
+
+ def toArrayString: Array[String] = {
+ val dicString: Array[String] = dic.map{case(k, v) => k + "|-|" + v.toString}
+ val alphaString: Array[String] = alpha.map(_.toString)
+ val emptyLine: Array[String] = Array("|--|")
+ head ++ emptyLine ++ dicString ++ emptyLine ++ alphaString
+ }
+
+ /**
+ * Verify CRF model
+ *
+ * @param tests Source files to be verified
+ * @return Source files with the predictive labels
+ */
+ def predict(tests: RDD[Sequence]): RDD[Sequence] = {
+ val bcModel = tests.context.broadcast(this)
+ tests.map { test =>
+ bcModel.value.testCRF(test, costFactor, verboseMode)
+ }
+ }
+
+ def predict(tests: Array[Sequence]): Array[Sequence] = {
+ tests.map(this.testCRF(_, costFactor, verboseMode))
+ }
+ /**
+ * Internal method to test the CRF model
+ *
+ * @param test the sequence to be tested
+ * @return the sequence along with predictive labels
+ */
+ def testCRF(test: Sequence,
+ costFactor: Double, vMode: Option[VerboseMode]): Sequence = {
+ val deFeatureIdx = new FeatureIndex()
+ deFeatureIdx.readModel(this)
+ val tagger = new Tagger(deFeatureIdx.labels.size, TestMode)
+ tagger.setCostFactor(costFactor)
+ tagger.setNBest(nBest)
+ tagger.read(test, deFeatureIdx)
+ deFeatureIdx.buildFeatures(tagger)
+ tagger.parse(deFeatureIdx.alpha, vMode)
+ var Seq: Sequence = null
+ if (vMode.isDefined) {
+ val tokens = new ArrayBuffer[Token]()
+ val labels = deFeatureIdx.labels
+ val tmp = test.toArray
+ for (i <- tmp.indices) {
+ val probMat = new ArrayBuffer[(String, Double)]()
+ vMode match {
+ case Some(VerboseLevel1) =>
+ probMat.append((labels(tagger.result(i)), tagger.probMatrix(i * labels.length + tagger.result(i))))
+ case Some(VerboseLevel2) =>
+ for (j <- labels.indices)
+ probMat.append((labels(j), tagger.probMatrix(i * labels.length + j)))
+ case _ =>
+ }
+ tokens.append(Token.put(labels(tagger.result(i)), tmp(i).tags).setProb(probMat.toArray))
+ }
+ Seq = Sequence(tokens.toArray).setSeqProb(tagger.seqProb)
+ }
+ else {
+ Seq = Sequence(test.toArray.map(x =>
+ Token.put(deFeatureIdx.labels(tagger.result(test.toArray.indexOf(x))), x.tags)
+ ))
+ }
+ if(nBest > 0) {
+ Seq.setCandidates(tagger.topN, tagger.probN, deFeatureIdx.labels )
+ }
+
+ Seq
+ }
+}
+
+object CRFModel {
+ def load(source: String): CRFModel = {
+ val components = source.split("""\|--\|""")
+ require(components.length == 3, "Incompatible formats in Model file")
+ val head = components(0).split("\t")
+ val dic = components(1).split("\t").map(x => {
+ val xx = x.split("""\|-\|""")
+ require(xx.length == 2, "Incompatible formats in Model file")
+ (xx(0), xx(1).toInt)
+ })
+ val alpha = components(2).split("\t").map(_.toDouble)
+ CRFModel(head, dic, alpha)
+ }
+
+ def loadBinaryFile(path: String): CRFModel = {
+ val source = scala.io.Source.fromFile(path + "/head").getLines().toArray.head
+ val components = source.split("""\|--\|""")
+ require(components.length == 2, "Incompatible formats in Model file")
+ val head = components(0).split("\t")
+ val dic = components(1).split("\t").map(x => {
+ val xx = x.split("""\|-\|""")
+ require(xx.length == 2, "Incompatible formats in Model file")
+ (xx(0), xx(1).toInt)
+ })
+ val alpha = Array.fill(head(1).toInt)(0.0)
+ val infile = new FileInputStream(path + "/alpha")
+ val in: DataInputStream = new DataInputStream(infile)
+ for(i <- alpha.indices)
+ alpha(i) = in.readFloat()
+ in.close()
+ CRFModel(head, dic, alpha)
+ }
+
+ def loadArray(source: Array[String]): CRFModel = {
+ val head = new ArrayBuffer[String]()
+ val dic = new ArrayBuffer[String]()
+ val alpha = new ArrayBuffer[String]()
+ var sentinel: Int = 0
+ for(line <- source) {
+ if (line == "|--|") {
+ sentinel += 1
+ }
+ else {
+ sentinel match {
+ case 0 => head.append(line)
+ case 1 => dic.append(line)
+ case 2 => alpha.append(line)
+ case _ => throw new RuntimeException("Incompatible formats in Model")
+ }
+ }
+ }
+ CRFModel(head.toArray, dic.toArray.map(x => {
+ val xx = x.split("""\|-\|""")
+ require(xx.length == 2, "Incompatible formats in Model file")
+ (xx(0), xx(1).toInt)
+ }), alpha.toArray.map(_.toDouble))
+ }
+
+ def save(model: CRFModel): String = {
+ model.toString
+ }
+
+ def saveBinaryFile(model: CRFModel, path: String): Unit = {
+ val head = model.toStringHead
+ new java.io.PrintWriter(path + "/head") { write(head); close() }
+ val outfile = new FileOutputStream(path + "/alpha")
+ val out: DataOutputStream = new DataOutputStream(
+ new BufferedOutputStream(
+ Files.newOutputStream(
+ Paths.get(path + "/alpha"), StandardOpenOption.APPEND
+ )
+ )
+ )
+ val alpha = model.alpha.map(_.toFloat)
+ for(i <- alpha.indices)
+ out.writeFloat(alpha(i))
+ out.close()
+ }
+
+ def saveArray(model: CRFModel): Array[String] = {
+ model.toArrayString
+ }
+}
diff --git a/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/Data.scala b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/Data.scala
new file mode 100644
index 0000000000000000000000000000000000000000..af5d0485890ef8022ddc9e4e73d46c53e8f0ac3a
--- /dev/null
+++ b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/Data.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// scalastyle:off
+package com.intel.ssg.bdt.nlp
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Class that represents the columns of a token.
+ *
+ * @param label The last column for this token.
+ * @param tags List of tags for this token, expect for the last label.
+ */
+class Token(
+ val label: String,
+ val tags: Array[String]) extends Serializable {
+ var prob : Array[(String, Double)] = null
+
+ def setProb(probMat: Array[(String, Double)]): Token = {
+ this.prob = probMat
+ this
+ }
+
+ def probPrinter(): String = {
+ val strRes = new StringBuffer()
+ strRes.append( tags.mkString("\t") )
+ strRes.append( "\t" + label + "\t")
+ strRes.append(prob.map{
+ case (str, p) => str + "/" + p.toString
+ }.mkString("\t") )
+ strRes.toString
+ }
+
+ override def toString: String = {
+ s"$label|--|${tags.mkString("|-|")}"
+ }
+
+ def compare(other: Token): Int = {
+ if (this.label == other.label) 1 else 0
+ }
+}
+
+object Token {
+ /**
+ * Parses a string resulted from `LabeledToken#toString` into
+ * an [[com.intel.ssg.bdt.nlp.Token]].
+ *
+ */
+ def deSerializer(s: String): Token = {
+ val parts = s.split("""\|--\|""")
+ val label = parts(0)
+ val tags = parts(1).split("""\|-\|""")
+ Token.put(label, tags)
+ }
+
+ def serializer(token: Token): String = {
+ token.toString
+ }
+
+ def put(label: String, tags: Array[String]): Token = {
+ new Token(label, tags)
+ }
+
+ def put(tags: Array[String]): Token = {
+ new Token(null, tags)
+ }
+}
+
+/**
+ * Class that represents the tokens of a sentence.
+ *
+ * @param sequence List of tokens
+ */
+case class Sequence (sequence: Array[Token]) extends Serializable {
+ var seqProb = 0.0
+ lazy val candidates = ArrayBuffer.empty[Sequence]
+
+ def setSeqProb(seqProb: Double): Sequence = {
+ this.seqProb = seqProb
+ this
+ }
+
+ def setCandidates(nBest: ArrayBuffer[Array[Int]],
+ probN: ArrayBuffer[Double],
+ labels: ArrayBuffer[String]): Sequence = {
+ for (i <- nBest.indices) {
+ val tokens = new ArrayBuffer[Token]()
+ for(j <- sequence.indices) {
+ tokens += Token.put(labels(nBest(i)(j)), sequence(j).tags)
+ }
+ candidates += Sequence(tokens.toArray).setSeqProb(probN(i))
+ }
+ this
+ }
+
+ def Print(): String = {
+ val strRes = new ArrayBuffer[String]()
+ strRes.append("#" + "\t" + seqProb.toString)
+ val pairs = this.toArray
+ for(i <- pairs.indices) {
+ strRes.append(pairs(i).tags.mkString("\t") + "\t" + pairs(i).label)
+ }
+ strRes.mkString("\n")
+ }
+
+ def nthPrint(k: Int): String = {
+ val strRes = new ArrayBuffer[String]()
+ strRes.append("#" + k + "\t" +candidates(k).seqProb.toString)
+ val pairs = this.candidates(k).toArray
+ for(i <- pairs.indices) {
+ strRes.append(pairs(i).tags.mkString("\t") + "\t" + pairs(i).label)
+ }
+ strRes.mkString("\n")
+ }
+
+ def nBestPrint(): String = {
+ val idx = candidates.indices
+ idx.map (t => nthPrint(t)).mkString("\n")
+ }
+
+ override def toString: String = {
+ seqProb match {
+ case 0.0 => s"${sequence.mkString("\t")}"
+ case _ => "#" + seqProb.toString + "\t" + s"${sequence.mkString("\t")}"
+ }
+ }
+
+ def toArray: Array[Token] = sequence
+
+ def compare(other: Sequence): Int = {
+ this.toArray.zip(other.toArray).map{case(one, two) => one.compare(two)}.sum
+ }
+
+ def probPrinter(): String = {
+ val strRes = new ArrayBuffer[String]()
+ strRes.append("|-#-|" + seqProb.toString)
+ strRes ++= this.toArray.map(_.probPrinter())
+ strRes.mkString("\n")
+ }
+
+}
+
+object Sequence {
+ def deSerializer(s: String): Sequence = {
+ val tokens = s.split("\t")
+ tokens.head.substring(0, 5) match {
+ case """"\|-#-\|"""" => val seqProb = tokens.head.substring(5).toDouble
+ Sequence(tokens.tail.map(Token.deSerializer)).setSeqProb(seqProb)
+ case _ => Sequence(tokens.map(Token.deSerializer))
+ }
+ }
+ def serializer(sequence: Sequence): String = {
+ sequence.toString
+ }
+}
diff --git a/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/FeatureIndex.scala b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/FeatureIndex.scala
new file mode 100644
index 0000000000000000000000000000000000000000..725172c97497d3dab9a8e92a11b7954cd6e18b4d
--- /dev/null
+++ b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/FeatureIndex.scala
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// scalastyle:off
+package com.intel.ssg.bdt.nlp
+
+import breeze.linalg.{DenseVector => BDV}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+class FeatureIndex extends Serializable {
+
+ var maxID = 0
+ var alpha: BDV[Double] = _
+ var tokensSize = 0
+ val unigramTempls = new ArrayBuffer[String]()
+ val bigramTempls = new ArrayBuffer[String]()
+ var labels = new ArrayBuffer[String]()
+ val dic = mutable.HashMap[String, (Int, Int)]()
+ val kMaxContextSize = 4
+ val BOS = Array("_B-1", "_B-2", "_B-3", "_B-4")
+ val EOS = Array("_B+1", "_B+2", "_B+3", "_B+4")
+
+ def initAlpha(): BDV[Double] = {
+ alpha = BDV.zeros[Double](maxID)
+ alpha
+ }
+
+ def openTagSet(sentence: Sequence): FeatureIndex = {
+ val tokenNum = sentence.toArray.map(_.tags.length).distinct
+ require(tokenNum.length == 1,
+ "The number of columns should be fixed in each token!")
+
+ labels.appendAll(sentence.toArray.map(_.label))
+ tokensSize = tokenNum.head
+ this
+ }
+
+ /**
+ * Build feature index
+ */
+ def buildFeatures(tagger: Tagger): Tagger = {
+ List(unigramTempls, bigramTempls).foreach{ templs =>
+ tagger.x.foreach { token =>
+ if (tagger.x.head != token || templs.head.head.equals('U')) {
+ tagger.featureCacheIndex.append(tagger.featureCache.length)
+ templs.foreach { templ =>
+ val os = applyRule(templ, tagger.x.indexOf(token), tagger)
+ val id = dic.getOrElse(os, (-1, 0))._1
+ if (id != -1) tagger.featureCache.append(id)
+ }
+ tagger.featureCache.append(-1)
+ }
+ }
+ }
+ tagger
+ }
+
+ def buildDictionary(tagger: Tagger): mutable.HashMap[String, Int] = {
+ val dicLocal = mutable.HashMap[String, Int]()
+ List(unigramTempls, bigramTempls).foreach{ templs =>
+ tagger.x.foreach{ token =>
+ if (tagger.x.head != token || templs.head.head.equals('U')) {
+ templs.foreach{ templ =>
+ val os = applyRule(templ, tagger.x.indexOf(token), tagger)
+ if (dicLocal.get(os).isEmpty) {
+ dicLocal.update(os, 1)
+ } else {
+ val idx = dicLocal.get(os).get + 1
+ dicLocal.update(os, idx)
+ }
+ }
+ }
+ }
+ }
+ dicLocal
+ }
+
+ def applyRule(src: String, idx: Int, tagger: Tagger): String = {
+ val templ = src.split(":")
+ if (templ.size == 2) {
+ val cols = templ(1).split("/").map(_.substring(2))
+ templ(0) + ":" + cols.map(getIndex(_, idx, tagger)).reduce(_ + "/" + _)
+ } else if (templ.size == 1) {
+ templ(0)
+ } else {
+ throw new RuntimeException("Incompatible formats in Template")
+ }
+ }
+
+ def getIndex(src: String, pos: Int, tagger: Tagger): String = {
+ val coor = src.drop(1).dropRight(1).split(",")
+ require(coor.size == 2, "Incompatible formats in Template")
+ val row = coor(0).toInt
+ val col = coor(1).toInt
+ if (row < -kMaxContextSize || row > kMaxContextSize ||
+ col < 0 || col >= tokensSize) {
+ throw new RuntimeException("Incompatible formats in Template")
+ }
+ val idx = pos + row
+ if (idx < 0) {
+ BOS(- idx - 1)
+ } else if (idx >= tagger.x.size) {
+ EOS(idx - tagger.x.size)
+ } else {
+ tagger.x(idx)(col)
+ }
+ }
+
+ /**
+ * Read one template file
+ *
+ * @param lines the template file
+ */
+ def openTemplate(lines: Array[String]): Unit = {
+ var i: Int = 0
+ lines.foreach { t =>
+ t.head match {
+ case 'U' => unigramTempls += t
+ case 'B' => bigramTempls += t
+ case '#' =>
+ case _ => throw new RuntimeException("Incompatible formats in Templates")
+ }}
+ }
+
+ def saveModel: CRFModel = {
+ val head = new ArrayBuffer[String]()
+
+ head.append("maxid:")
+ head.append(maxID.toString)
+ head.append("cost-factor:")
+ head.append(1.0.toString)
+ head.append("xsize:")
+ head.append(tokensSize.toString)
+ head.append("Labels:")
+ labels.foreach(head.append(_))
+ head.append("UGrams:")
+ unigramTempls.foreach(head.append(_))
+ head.append("BGrams:")
+ bigramTempls.foreach(head.append(_))
+
+ CRFModel(head.toArray, dic.map { case (k, v) => (k, v._1) }.toArray, alpha.toArray)
+ }
+
+ def readModel(models: CRFModel): this.type = {
+ val contents: Array[String] = models.head
+ models.dic.foreach{case(k, v) => dic.update(k, (v, 1))}
+ alpha = new BDV(models.alpha)
+
+ var i: Int = 0
+ var readMaxId: Boolean = false
+ var readCostFactor: Boolean = false
+ var readXSize: Boolean = false
+ var readLabels: Boolean = false
+ var readUGrams: Boolean = false
+ var readBGrams: Boolean = false
+ val alpha_tmp = new ArrayBuffer[Double]()
+ while (i < contents.length) {
+ contents(i) match {
+ case "maxid:" =>
+ readMaxId = true
+ case "cost-factor:" =>
+ readMaxId = false
+ readCostFactor = true
+ case "xsize:" =>
+ readCostFactor = false
+ readXSize = true
+ case "Labels:" =>
+ readXSize = false
+ readLabels = true
+ case "UGrams:" =>
+ readLabels = false
+ readUGrams = true
+ case "BGrams:" =>
+ readUGrams = false
+ readBGrams = true
+ case _ =>
+ i -= 1
+ }
+ i += 1
+ if (readMaxId) {
+ maxID = contents(i).toInt
+ } else if (readXSize) {
+ tokensSize = contents(i).toInt
+ } else if (readLabels) {
+ labels.append(contents(i))
+ } else if (readUGrams) {
+ unigramTempls.append(contents(i))
+ } else if (readBGrams) {
+ bigramTempls.append(contents(i))
+ }
+ i += 1
+ }
+ this
+ }
+
+ def openTagSetDist(trains: RDD[Sequence]) {
+ val features: RDD[FeatureIndex] = trains.map(new FeatureIndex().openTagSet)
+ val tokensSizeCollect = features.map(_.tokensSize).distinct().collect()
+ require(tokensSizeCollect.length == 1,
+ "The number of columns should be fixed in each token!")
+ tokensSize = tokensSizeCollect.head
+ labels = trains.map(f => toHashSet(f.sequence))
+ .reduce((a, b) => merge(a, b)).toArray.to[ArrayBuffer]
+ }
+
+ def toHashSet(tokens: Array[Token]): mutable.HashSet[String] = {
+ val labelSet : mutable.HashSet[String] = new mutable.HashSet[String]
+ for (t <- tokens) {
+ labelSet.add(t.label)
+ }
+ labelSet
+ }
+
+ def merge(a: mutable.HashSet[String], b: mutable.HashSet[String]): mutable.HashSet[String] = {
+ val labelSet: mutable.HashSet[String] = new mutable.HashSet[String]
+ for (t <- a) {
+ labelSet.add(t)
+ }
+ for (t <- b) {
+ labelSet.add(t)
+ }
+ labelSet
+ }
+
+ def buildDictionaryDist(
+ taggers: RDD[Tagger],
+ bcFeatureIdxI: Broadcast[FeatureIndex], freq: Int) {
+ // filter : use features that occur no less than freq(default 1)
+ val dictionary = taggers.flatMap(tagger => {
+ bcFeatureIdxI.value.buildDictionary(tagger)
+ }).reduceByKey(_ + _)
+ .filter(_._2 >= freq)
+ val dictionaryUni: RDD[(String, (Int, Int))] = dictionary.filter(_._1.head == 'U')
+ .zipWithIndex()
+ .map { case((feature, frequency), featureID) =>
+ (feature, (featureID.toInt * bcFeatureIdxI.value.labels.size, frequency))
+ }
+ val bcOffSet = taggers.context.broadcast(dictionaryUni.count().toInt * labels.size)
+ val dictionaryBi: RDD[(String, (Int, Int))] = dictionary.filter(_._1.head == 'B').zipWithIndex()
+ .map{ case((feature, frequency), featureID) =>
+ (feature, (featureID.toInt * bcFeatureIdxI.value.labels.size * bcFeatureIdxI.value.labels.size + bcOffSet.value, frequency))
+ }
+
+ val dictionaryGram = dictionaryUni.union(dictionaryBi).collect()
+
+ dictionaryGram.foreach{case(k, v) => dic.update(k, v)}
+ maxID = dictionaryGram.map(_._2._1).max + labels.size * labels.size
+
+ }
+}
diff --git a/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/Graph.scala b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/Graph.scala
new file mode 100644
index 0000000000000000000000000000000000000000..22b13fc46d924d680bcf96a8adfa7ff8211d5a44
--- /dev/null
+++ b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/Graph.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// scalastyle:off
+package com.intel.ssg.bdt.nlp
+
+import breeze.linalg.{Vector => BV}
+
+import scala.collection.mutable.ArrayBuffer
+
+private[nlp] class Node extends Serializable {
+ var x = 0
+ var y = 0
+ var alpha = 0.0
+ var beta = 0.0
+ var cost = 0.0
+ var bestCost = 0.0
+ var prev: Option[Node] = None
+ var fVector = 0
+ val lPath = new ArrayBuffer[Path]()
+ val rPath = new ArrayBuffer[Path]()
+
+
+ /**
+ * simplify the log likelihood.
+ */
+ def logSumExp(x: Double, y: Double, flg: Boolean): Double = {
+ val MINUS_LOG_EPSILON = 50.0
+ if (flg) y
+ else {
+ val vMin: Double = math.min(x, y)
+ val vMax: Double = math.max(x, y)
+ if (vMax > vMin + MINUS_LOG_EPSILON) vMax else vMax + math.log(math.exp(vMin - vMax) + 1.0)
+ }
+ }
+
+ def calcAlpha(nodes: ArrayBuffer[Node]): Unit = {
+ alpha = 0.0
+ for(i <- lPath.indices)
+ alpha = logSumExp(alpha, lPath(i).cost + nodes(lPath(i).lNode).alpha, i == 0)
+ alpha += cost
+ }
+
+ def calcBeta(nodes: ArrayBuffer[Node]): Unit = {
+ beta = 0.0
+ for(i <- rPath.indices)
+ beta = logSumExp(beta, rPath(i).cost + nodes(rPath(i).rNode).beta, i == 0)
+ beta += cost
+ }
+
+ def calExpectation(
+ expected: BV[Double],
+ Z: Double,
+ size: Int,
+ featureCache: ArrayBuffer[Int],
+ nodes: ArrayBuffer[Node]): Unit = {
+ val c: Double = math.exp(alpha + beta -cost - Z)
+
+ var idx: Int = fVector
+ while (featureCache(idx) != -1) {
+ expected(featureCache(idx) + y) += c
+ idx += 1
+ }
+
+ for(i <- lPath.indices)
+ lPath(i).calExpectation(expected, Z, size, featureCache, nodes)
+
+ }
+}
+
+private[nlp] class Path extends Serializable {
+ var rNode = 0
+ var lNode = 0
+ var cost = 0.0
+ var fVector = 0
+
+ def calExpectation(
+ expected: BV[Double],
+ Z: Double,
+ size: Int,
+ featureCache: ArrayBuffer[Int],
+ nodes: ArrayBuffer[Node]): Unit = {
+ val c: Double = math.exp(nodes(lNode).alpha + cost + nodes(rNode).beta - Z)
+ var idx: Int = fVector
+
+ while (featureCache(idx) != -1) {
+ expected(featureCache(idx) + nodes(lNode).y * size + nodes(rNode).y) += c
+ idx += 1
+ }
+ }
+
+ def add(lnd: Int, rnd: Int, nodes: ArrayBuffer[Node]): Unit = {
+ lNode = lnd
+ rNode = rnd
+ nodes(lNode).rPath.append(this)
+ nodes(rNode).lPath.append(this)
+ }
+}
diff --git a/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/Tagger.scala b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/Tagger.scala
new file mode 100644
index 0000000000000000000000000000000000000000..404f7bc53f63043663ea84cc0bdc7b51ed1a6749
--- /dev/null
+++ b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/Tagger.scala
@@ -0,0 +1,297 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// scalastyle:off
+package com.intel.ssg.bdt.nlp
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import breeze.linalg.{DenseVector => BDV, Vector => BV}
+
+private[nlp] trait Mode
+
+private[nlp] case object LearnMode extends Mode
+
+private[nlp] case object TestMode extends Mode
+
+private[nlp] case class QueueElement(node : Node, fx : Double, gx : Double, next : QueueElement)
+
+class Tagger (
+ ySize: Int,
+ mode: Mode = LearnMode) extends Serializable {
+ var nBest = 0
+ var cost = 0.0
+ var Z = 0.0
+ var obj = 0.0
+ var costFactor = 1.0
+ var buildFlag = false
+ val x = new ArrayBuffer[Array[String]]()
+ val nodes = new ArrayBuffer[Node]()
+ val answer = new ArrayBuffer[Int]()
+ val result = new ArrayBuffer[Int]()
+ val featureCache = new ArrayBuffer[Int]()
+ val featureCacheIndex = new ArrayBuffer[Int]()
+ val probMatrix = new ArrayBuffer[Double]()
+ var seqProb = 0.0
+ lazy val topN = ArrayBuffer.empty[Array[Int]]
+ lazy val topNResult = ArrayBuffer.empty[Int]
+ lazy val probN = ArrayBuffer.empty[Double]
+ lazy val agenda = mutable.PriorityQueue.empty[QueueElement] (
+ Ordering.by((_: QueueElement).fx).reverse
+ )
+
+ def setCostFactor(costFactor: Double): Tagger = {
+ this.costFactor = costFactor
+ this
+ }
+
+ def setNBest(nBest: Int): Tagger = {
+ this.nBest = nBest
+ this
+ }
+
+ def read(lines: Sequence, feature_idx: FeatureIndex): Unit = {
+ lines.toArray.foreach{ t =>
+ mode match {
+ case LearnMode =>
+ for (y <- feature_idx.labels if y.equals(t.label))
+ answer.append(feature_idx.labels.indexOf(y))
+ x.append(t.tags)
+ case TestMode =>
+ x.append(t.tags)
+ answer.append(0)
+ }
+ result.append(0)
+ }
+ }
+
+
+ /**
+ * Set node relationship and its feature index.
+ * Node represents a token.
+ */
+ def rebuildFeatures(): Unit = {
+
+ nodes ++= Array.fill(x.length * ySize)(new Node)
+ nodes.zipWithIndex.foreach{ case(n, index) =>
+ n.x = index / ySize
+ n.y = index - n.x * ySize
+ n.fVector = featureCacheIndex(n.x)
+ }
+
+ nodes.filter(_.x > 0).foreach{ n =>
+ val paths = Array.fill(ySize)(new Path)
+ paths.zipWithIndex.foreach { case(p, indexP) =>
+ p.fVector = featureCacheIndex(n.x + x.length - 1)
+ p.add((n.x - 1) * ySize + n.y, n.x * ySize + indexP, nodes)
+ }
+ }
+ }
+
+ /**
+ * Calculate the expectation of each node
+ */
+ def forwardBackward(): Unit = {
+ nodes.foreach(_.calcAlpha(nodes))
+ nodes.reverse.foreach(_.calcBeta(nodes))
+ Z = 0.0
+ nodes.filter(_.x == 0).foreach(n => Z = n.logSumExp(Z, n.beta, n.y == 0))
+ }
+
+ /**
+ * Get the max expectation in the nodes and predicts the most likely label
+ */
+ def viterbi(): Unit = {
+ var bestCost = Double.MinValue
+ var best: Option[Node] = None
+
+ nodes.foreach { n =>
+ bestCost = Double.MinValue
+ best = None
+ n.lPath.foreach { p =>
+ val cost = nodes(p.lNode).bestCost + p.cost + n.cost
+ if (cost > bestCost) {
+ bestCost = cost
+ best = Some(nodes(p.lNode))
+ }
+ }
+ n.prev = best
+ best match {
+ case None =>
+ n.bestCost = n.cost
+ case _ =>
+ n.bestCost = bestCost
+ }
+ }
+
+ var nd: Option[Node] = Some(nodes.filter(_.x == x.length - 1).max(Ordering.by((_:Node).bestCost)))
+
+ while (nd.isDefined) {
+ result.update(nd.get.x, nd.get.y)
+ nd = nd.get.prev
+ }
+
+ cost = - nodes((x.length - 1)*ySize + result.last).bestCost
+ }
+
+ def gradient(expected: BV[Double], alpha: BDV[Double]): Double = {
+
+ buildLattice(alpha)
+ forwardBackward()
+
+ nodes.foreach(_.calExpectation(expected, Z, ySize, featureCache, nodes))
+
+ var s: Double = 0.0
+ for (i <- x.indices) {
+ var rIdx = nodes(i * ySize + answer(i)).fVector
+ while (featureCache(rIdx) != -1) {
+ expected(featureCache(rIdx) + answer(i)) -= 1.0
+ rIdx += 1
+ }
+ s += nodes(i * ySize + answer(i)).cost
+ var j = 0
+ while (j < nodes(i * ySize + answer(i)).lPath.length) {
+ val lNode = nodes(nodes(i * ySize + answer(i)).lPath(j).lNode)
+ val rNode = nodes(nodes(i * ySize + answer(i)).lPath(j).rNode)
+ val lPath = nodes(i * ySize + answer(i)).lPath(j)
+ if (lNode.y == answer(lNode.x)) {
+ rIdx = lPath.fVector
+ while (featureCache(rIdx) != -1) {
+ expected(featureCache(rIdx) + lNode.y * ySize + rNode.y) -= 1.0
+ rIdx += 1
+ }
+ s += lPath.cost
+ }
+ j += 1
+ }
+ }
+
+ Z - s
+ }
+
+ def probCalculate(): Unit = {
+ probMatrix ++= Array.fill(x.length * ySize)(0.0)
+ var idx: Int = 0
+ nodes.foreach{ n =>
+ idx = n.x * ySize + n.y
+ probMatrix(idx) = Math.exp(n.alpha + n.beta - n.cost - Z)
+ }
+ this.seqProb = Math.exp(- cost - Z)
+
+ }
+
+ def clear(): Unit = {
+ nodes foreach{ n =>
+ n.lPath.clear()
+ n.rPath.clear()
+ }
+ nodes.clear()
+ }
+
+ def parse(alpha: BDV[Double], mode: Option[VerboseMode]): Unit = {
+ buildLattice(alpha)
+ if (nBest > 0 || mode.isDefined) {
+ forwardBackward()
+ viterbi()
+ probCalculate()
+ } else {
+ viterbi()
+ }
+ if(nBest > 0) {
+ // initialize nBest
+ if (agenda.nonEmpty) agenda.clear()
+ nodes.slice((x.size - 1) * ySize, x.size * ySize - 1)
+ .foreach(n => agenda += QueueElement(n, - n.bestCost, - n.cost, null))
+ // find nBest
+ for(i <- 0 until this.nBest) {
+ topNResult.clear()
+ if (!nextNode) {
+ return
+ }
+ probN.append(Math.exp(- cost - Z))
+ topN.append(topNResult.toArray)
+ }
+ }
+ }
+
+ def buildLattice(alpha: BDV[Double]): Unit = {
+
+ if (!buildFlag) {
+ rebuildFeatures()
+ buildFlag = true
+ }
+
+ nodes.foreach { n =>
+ val nn = calcCost(n, alpha)
+ nn.lPath.foreach(calcCost(_, alpha))
+ nn
+ }
+ }
+
+ def calcCost(n: Node, alpha: BDV[Double]): Node = {
+ var cd: Double = 0.0
+ var idx: Int = n.fVector
+ n.cost = 0.0
+
+ while (featureCache(idx) != -1) {
+ cd += alpha(featureCache(idx) + n.y)
+ n.cost = cd * costFactor
+ idx += 1
+ }
+
+ n
+ }
+
+ def calcCost(p: Path, alpha: BDV[Double]): Path = {
+ var cd: Double = 0.0
+ var idx: Int = p.fVector
+ p.cost = 0.0
+
+ while (featureCache(idx) != -1) {
+ cd += alpha(featureCache(idx) +
+ nodes(p.lNode).y * ySize + nodes(p.rNode).y)
+ p.cost = cd * costFactor
+ idx += 1
+ }
+
+ p
+ }
+
+ def nextNode: Boolean = {
+ var top: QueueElement = null
+ var rNode: Node = null
+ while(agenda.nonEmpty) {
+ top = agenda.dequeue()
+ rNode = top.node
+ if(rNode.x == 0) {
+ var n: QueueElement = top
+ for(i <- x.indices) {
+ topNResult.append(n.node.y)
+ n = n.next
+ }
+ cost = top.gx
+ return true
+ }
+ rNode.lPath.foreach { p =>
+ val gx = -nodes(p.lNode).cost - p.cost + top.gx
+ val fx = - nodes(p.lNode).bestCost - p.cost + top.gx
+ agenda += QueueElement(nodes(p.lNode), fx, gx, top)
+ }
+ }
+ false
+ }
+}
diff --git a/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/UpdaterCRF.scala b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/UpdaterCRF.scala
new file mode 100644
index 0000000000000000000000000000000000000000..20c51273532a9ab9fe353ac80ed7bdec9a97d06e
--- /dev/null
+++ b/ml-core/src/main/scala/com/intel/ssg/bdt/nlp/UpdaterCRF.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.intel.ssg.bdt.nlp
+
+import breeze.linalg.{DenseVector => BDV}
+
+import org.apache.spark.mllib.linalg.{Vector => SparkVector}
+import org.apache.spark.mllib.optimization.Updater
+
+trait UpdaterCRF extends Updater {
+ def compute(
+ weightsOld: SparkVector,
+ gradient: SparkVector,
+ stepSize: Double,
+ iter: Int,
+ regParam: Double): (SparkVector, Double) = {
+ throw new Exception("The original compute() method is not supported")
+ }
+
+ def computeCRF(
+ weightsOld: BDV[Double],
+ gradient: BDV[Double],
+ regParam: Double): (BDV[Double], Double)
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/classification/KNNClassifier.scala b/ml-core/src/main/scala/org/apache/spark/ml/classification/KNNClassifier.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b7ab9aac0fcb5f7689fd3759bf226a3de0bed9cd
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/classification/KNNClassifier.scala
@@ -0,0 +1,241 @@
+//scalastyle:off
+package org.apache.spark.ml.classification
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.knn._
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.SparkException
+import org.apache.spark.ml.stat.MultiClassSummarizer
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * [[https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm]] for classification.
+ * An object is classified by a majority vote of its neighbors, with the object being assigned to
+ * the class most common among its k nearest neighbors.
+ */
+class KNNClassifier(override val uid: String) extends ProbabilisticClassifier[Vector, KNNClassifier, KNNClassificationModel]
+ with KNNParams with HasWeightCol {
+
+ def this() = this(Identifiable.randomUID("knnc"))
+
+ /** @group setParam */
+ override def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ override def setLabelCol(value: String): this.type = {
+ set(labelCol, value)
+
+ if ($(weightCol).isEmpty) {
+ set(inputCols, Array(value))
+ } else {
+ set(inputCols, Array(value, $(weightCol)))
+ }
+ }
+
+ //fill in default label col
+ setDefault(inputCols, Array($(labelCol)))
+
+ /** @group setWeight */
+ def setWeightCol(value: String): this.type = {
+ set(weightCol, value)
+
+ if (value.isEmpty) {
+ set(inputCols, Array($(labelCol)))
+ } else {
+ set(inputCols, Array($(labelCol), value))
+ }
+ }
+
+ setDefault(weightCol -> "")
+
+ /** @group setParam */
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group setParam */
+ def setTopTreeSize(value: Int): this.type = set(topTreeSize, value)
+
+ /** @group setParam */
+ def setTopTreeLeafSize(value: Int): this.type = set(topTreeLeafSize, value)
+
+ /** @group setParam */
+ def setSubTreeLeafSize(value: Int): this.type = set(subTreeLeafSize, value)
+
+ /** @group setParam */
+ def setBufferSizeSampleSizes(value: Array[Int]): this.type = set(bufferSizeSampleSizes, value)
+
+ /** @group setParam */
+ def setBalanceThreshold(value: Double): this.type = set(balanceThreshold, value)
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ override protected def train(dataset: Dataset[_]): KNNClassificationModel = {
+ // Extract columns from data. If dataset is persisted, do not persist oldDataset.
+ val instances = extractLabeledPoints(dataset).map {
+ case LabeledPoint(label: Double, features: Vector) => (label, features)
+ }
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val labelSummarizer = instances.treeAggregate(
+ new MultiClassSummarizer)(
+ seqOp = (c, v) => (c, v) match {
+ case (labelSummarizer: MultiClassSummarizer, (label: Double, features: Vector)) =>
+ labelSummarizer.add(label)
+ },
+ combOp = (c1, c2) => (c1, c2) match {
+ case (classSummarizer1: MultiClassSummarizer, classSummarizer2: MultiClassSummarizer) =>
+ classSummarizer1.merge(classSummarizer2)
+ })
+
+ val histogram = labelSummarizer.histogram
+ val numInvalid = labelSummarizer.countInvalid
+ val numClasses = histogram.length
+
+ if (numInvalid != 0) {
+ val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
+ s"Found $numInvalid invalid labels."
+ logError(msg)
+ throw new SparkException(msg)
+ }
+
+ val knnModel = copyValues(new KNN()).fit(dataset)
+ knnModel.toNewClassificationModel(uid, numClasses)
+ }
+
+ override def fit(dataset: Dataset[_]): KNNClassificationModel = {
+ // Need to overwrite this method because we need to manually overwrite the buffer size
+ // because it is not supposed to stay the same as the Classifier if user sets it to -1.
+ transformSchema(dataset.schema, logging = true)
+ val model = train(dataset)
+ val bufferSize = model.getBufferSize
+ copyValues(model.setParent(this)).setBufferSize(bufferSize)
+ }
+
+ override def copy(extra: ParamMap): KNNClassifier = defaultCopy(extra)
+}
+
+class KNNClassificationModel private[ml](
+ override val uid: String,
+ val topTree: Broadcast[Tree],
+ val subTrees: RDD[Tree],
+ val _numClasses: Int
+) extends ProbabilisticClassificationModel[Vector, KNNClassificationModel]
+ with KNNModelParams with HasWeightCol with Serializable {
+ require(subTrees.getStorageLevel != StorageLevel.NONE,
+ "KNNModel is not designed to work with Trees that have not been cached")
+
+ /** @group setParam */
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group setParam */
+ def setBufferSize(value: Double): this.type = set(bufferSize, value)
+
+ override def numClasses: Int = _numClasses
+
+ //TODO: This can benefit from DataSet API
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val getWeight: Row => Double = {
+ if($(weightCol).isEmpty) {
+ r => 1.0
+ } else {
+ r => r.getDouble(1)
+ }
+ }
+
+ val neighborRDD : RDD[(Long, Array[(Row, Double)])] = transform(dataset, topTree, subTrees)
+ val merged = neighborRDD
+ .map {
+ case (id, labelsDists) =>
+ val (labels, _) = labelsDists.unzip
+ val vector = new Array[Double](numClasses)
+ var i = 0
+ while (i < labels.length) {
+ vector(labels(i).getDouble(0).toInt) += getWeight(labels(i))
+ i += 1
+ }
+ val rawPrediction = Vectors.dense(vector)
+ lazy val probability = raw2probability(rawPrediction)
+ lazy val prediction = probability2prediction(probability)
+
+ val values = new ArrayBuffer[Any]
+ if ($(rawPredictionCol).nonEmpty) {
+ values.append(rawPrediction)
+ }
+ if ($(probabilityCol).nonEmpty) {
+ values.append(probability)
+ }
+ if ($(predictionCol).nonEmpty) {
+ values.append(prediction)
+ }
+
+ (id, values)
+ }
+
+ dataset.sqlContext.createDataFrame(
+ dataset.toDF().rdd.zipWithIndex().map { case (row, i) => (i, row) }
+ .leftOuterJoin(merged) //make sure we don't lose any observations
+ .map {
+ case (i, (row, values)) => Row.fromSeq(row.toSeq ++ values.get)
+ },
+ transformSchema(dataset.schema)
+ )
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ var transformed = schema
+ if ($(rawPredictionCol).nonEmpty) {
+ transformed = SchemaUtils.appendColumn(transformed, $(rawPredictionCol), new VectorUDT)
+ }
+ if ($(probabilityCol).nonEmpty) {
+ transformed = SchemaUtils.appendColumn(transformed, $(probabilityCol), new VectorUDT)
+ }
+ if ($(predictionCol).nonEmpty) {
+ transformed = SchemaUtils.appendColumn(transformed, $(predictionCol), DoubleType)
+ }
+ transformed
+ }
+
+ override def copy(extra: ParamMap): KNNClassificationModel = {
+ val copied = new KNNClassificationModel(uid, topTree, subTrees, numClasses)
+ copyValues(copied, extra).setParent(parent)
+ }
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ var i = 0
+ val size = dv.size
+
+ var sum = 0.0
+ while (i < size) {
+ sum += dv.values(i)
+ i += 1
+ }
+
+ i = 0
+ while (i < size) {
+ dv.values(i) /= sum
+ i += 1
+ }
+
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in KNNClassificationModel:" +
+ " raw2probabilitiesInPlace encountered SparseVector")
+ }
+ }
+
+ override def predictRaw(features: Vector): Vector = {
+ throw new SparkException("predictRaw function should not be called directly since kNN prediction is done in distributed fashion. Use transform instead.")
+ }
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/knn/KNN.scala b/ml-core/src/main/scala/org/apache/spark/ml/knn/KNN.scala
new file mode 100644
index 0000000000000000000000000000000000000000..205f400b86a3a17ca9e6b955c098d651c9c88619
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/knn/KNN.scala
@@ -0,0 +1,590 @@
+//scalastyle:off
+package org.apache.spark.ml.knn
+
+import breeze.linalg.{DenseVector, Vector => BV}
+import breeze.stats._
+import breeze.stats.meanAndVariance.MeanAndVariance
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.classification.KNNClassificationModel
+import org.apache.spark.ml.knn.KNN.{KNNPartitioner, RowWithVector, VectorWithNorm}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.regression.KNNRegressionModel
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.linalg.{Vector, VectorUDT, Vectors}
+import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.{HashPartitioner, Partitioner}
+import org.apache.log4j
+import org.apache.spark.mllib.knn.KNNUtils
+
+import scala.annotation.tailrec
+import scala.collection.mutable.ArrayBuffer
+import scala.util.hashing.byteswap64
+
+// features column => vector, input columns => auxiliary columns to return by KNN model
+private[ml] trait KNNModelParams extends Params with HasFeaturesCol with HasInputCols {
+ /**
+ * Param for the column name for returned neighbors.
+ * Default: "neighbors"
+ *
+ * @group param
+ */
+ val neighborsCol = new Param[String](this, "neighborsCol", "column names for returned neighbors")
+
+ /** @group getParam */
+ def getNeighborsCol: String = $(neighborsCol)
+
+ /**
+ * Param for distance column that will create a distance column of each nearest neighbor
+ * Default: no distance column will be used
+ *
+ * @group param
+ */
+ val distanceCol = new Param[String](this, "distanceCol", "column that includes each neighbors' distance as an additional column")
+
+ /** @group getParam */
+ def getDistanceCol: String = $(distanceCol)
+
+ /**
+ * Param for number of neighbors to find (> 0).
+ * Default: 5
+ *
+ * @group param
+ */
+ val k = new IntParam(this, "k", "number of neighbors to find", ParamValidators.gt(0))
+
+ /** @group getParam */
+ def getK: Int = $(k)
+
+ /**
+ * Param for maximum distance to find neighbors
+ * Default: Double.PositiveInfinity
+ *
+ * @group param
+ */
+ val maxDistance = new DoubleParam(this, "maxNeighbors", "maximum distance to find neighbors", // todo: maxDistance or maxNeighbors?
+ ParamValidators.gt(0))
+
+ /** @group getParam */
+ def getMaxDistance: Double = $(maxDistance)
+
+ /**
+ * Param for size of buffer used to construct spill trees and top-level tree search.
+ * Note the buffer size is 2 * tau as described in the paper.
+ *
+ * When buffer size is 0.0, the tree itself reverts to a metric tree.
+ * -1.0 triggers automatic effective nearest neighbor distance estimation.
+ *
+ * Default: -1.0
+ *
+ * @group param
+ */
+ val bufferSize = new DoubleParam(this, "bufferSize",
+ "size of buffer used to construct spill trees and top-level tree search", ParamValidators.gtEq(-1.0))
+
+ /** @group getParam */
+ def getBufferSize: Double = $(bufferSize)
+
+ private[ml] def transform(data: RDD[Vector], topTree: Broadcast[Tree], subTrees: RDD[Tree]): RDD[(Long, Array[(Row,Double)])] = {
+ val searchData = data.zipWithIndex()
+ .flatMap {
+ case (vector, index) =>
+ val vectorWithNorm = new VectorWithNorm(vector)
+ val idx = KNN.searchIndices(vectorWithNorm, topTree.value, $(bufferSize))
+ .map(i => (i, (vectorWithNorm, index)))
+
+ assert(idx.nonEmpty, s"indices must be non-empty: $vector ($index)")
+ idx
+ }
+ .partitionBy(new HashPartitioner(subTrees.partitions.length))
+
+ // for each partition, search points within corresponding child tree
+ val results = searchData.zipPartitions(subTrees) {
+ (childData, trees) =>
+ val tree = trees.next()
+ assert(!trees.hasNext)
+ childData.flatMap {
+ case (_, (point, i)) =>
+ tree.query(point, $(k)).collect {
+ case (neighbor, distance) if distance <= $(maxDistance) =>
+ (i, (neighbor.row, distance))
+ }
+ }
+ }
+
+ // merge results by point index together and keep topK results
+ results.topByKey($(k))(Ordering.by(-_._2))
+ .map { case (i, seq) => (i, seq) }
+ }
+
+ private[ml] def transform(dataset: Dataset[_], topTree: Broadcast[Tree], subTrees: RDD[Tree]): RDD[(Long, Array[(Row, Double)])] = {
+ transform(dataset.select($(featuresCol)).rdd.map(_.getAs[Vector](0)), topTree, subTrees)
+ }
+
+}
+
+private[ml] trait KNNParams extends KNNModelParams with HasSeed {
+ /**
+ * Param for number of points to sample for top-level tree (> 0).
+ * Default: 1000
+ *
+ * @group param
+ */
+ val topTreeSize = new IntParam(this, "topTreeSize", "number of points to sample for top-level tree", ParamValidators.gt(0))
+
+ /** @group getParam */
+ def getTopTreeSize: Int = $(topTreeSize)
+
+ /**
+ * Param for number of points at which to switch to brute-force for top-level tree (> 0).
+ * Default: 5
+ *
+ * @group param
+ */
+ val topTreeLeafSize = new IntParam(this, "topTreeLeafSize",
+ "number of points at which to switch to brute-force for top-level tree", ParamValidators.gt(0))
+
+ /** @group getParam */
+ def getTopTreeLeafSize: Int = $(topTreeLeafSize)
+
+ /**
+ * Param for number of points at which to switch to brute-force for distributed sub-trees (> 0).
+ * Default: 20
+ *
+ * @group param
+ */
+ val subTreeLeafSize = new IntParam(this, "subTreeLeafSize",
+ "number of points at which to switch to brute-force for distributed sub-trees", ParamValidators.gt(0))
+
+ /** @group getParam */
+ def getSubTreeLeafSize: Int = $(subTreeLeafSize)
+
+ /**
+ * Param for number of sample sizes to take when estimating buffer size (at least two samples).
+ * Default: 100 to 1000 by 100
+ *
+ * @group param
+ */
+ val bufferSizeSampleSizes = new IntArrayParam(this, "bufferSizeSampleSize", // todo: should this have an 's' at the end?
+ "number of sample sizes to take when estimating buffer size", { arr: Array[Int] => arr.length > 1 && arr.forall(_ > 0) })
+
+ /** @group getParam */
+ def getBufferSizeSampleSizes: Array[Int] = $(bufferSizeSampleSizes)
+
+ /**
+ * Param for fraction of total points at which spill tree reverts back to metric tree
+ * if either child contains more points (0 <= rho <= 1).
+ * Default: 70%
+ *
+ * @group param
+ */
+ val balanceThreshold = new DoubleParam(this, "balanceThreshold",
+ "fraction of total points at which spill tree reverts back to metric tree if either child contains more points",
+ ParamValidators.inRange(0, 1))
+
+ /** @group getParam */
+ def getBalanceThreshold: Double = $(balanceThreshold)
+
+ setDefault(topTreeSize -> 1000, topTreeLeafSize -> 10, subTreeLeafSize -> 30,
+ bufferSize -> -1.0, bufferSizeSampleSizes -> (100 to 1000 by 100).toArray, balanceThreshold -> 0.7,
+ k -> 5, neighborsCol -> "neighbors", distanceCol -> "", maxDistance -> Double.PositiveInfinity)
+
+ /**
+ * Validates and transforms the input schema.
+ *
+ * @param schema input schema
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ val auxFeatures = $(inputCols).map(c => schema(c))
+ val schemaWithNeighbors = SchemaUtils.appendColumn(schema, $(neighborsCol), ArrayType(StructType(auxFeatures)))
+
+ if ($(distanceCol).isEmpty) {
+ schemaWithNeighbors
+ } else {
+ SchemaUtils.appendColumn(schemaWithNeighbors, $(distanceCol), ArrayType(DoubleType))
+ }
+ }
+}
+
+/**
+ * kNN Model facilitates k-Nestrest Neighbor search by storing distributed hybrid spill tree.
+ * Top level tree is a MetricTree but instead of using back tracking, it searches all possible leaves in parallel
+ * to avoid multiple iterations. It uses the same buffer size that is used in model training, when the search
+ * vector falls into the buffer zone of the node, it dispatches search to both children.
+ *
+ * A high level overview of the search phases is as follows:
+ *
+ * 1. For each vector to search, go through the top level tree to output a pair of (index, point)
+ * 1. Repartition search points by partition index
+ * 1. Search each point through the hybrid spill tree in that particular partition
+ * 1. For each point, merge results from different partitions and keep top k results.
+ *
+ */
+class KNNModel private[ml](
+ override val uid: String,
+ val topTree: Broadcast[Tree],
+ val subTrees: RDD[Tree]
+) extends Model[KNNModel] with KNNModelParams {
+ require(subTrees.getStorageLevel != StorageLevel.NONE,
+ "KNNModel is not designed to work with Trees that have not been cached")
+
+ /** @group setParam */
+ def setNeighborsCol(value: String): this.type = set(neighborsCol, value)
+
+ /** @group setParam */
+ def setDistanceCol(value: String): this.type = set(distanceCol, value)
+
+ /** @group setParam */
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group setParam */
+ def setMaxDistance(value: Double): this.type = set(maxDistance, value)
+
+ /** @group setParam */
+ def setBufferSize(value: Double): this.type = set(bufferSize, value)
+
+ //TODO: All these can benefit from DataSet API
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val merged: RDD[(Long, Array[(Row,Double)])] = transform(dataset, topTree, subTrees)
+
+ val withDistance = $(distanceCol).nonEmpty
+
+ dataset.sqlContext.createDataFrame(
+ dataset.toDF().rdd.zipWithIndex().map { case (row, i) => (i, row) }
+ .leftOuterJoin(merged)
+ .map {
+ case (i, (row, neighborsAndDistances)) =>
+ val (neighbors, distances) = neighborsAndDistances.map(_.unzip).getOrElse((Array.empty[Row], Array.empty[Double]))
+ if (withDistance) {
+ Row.fromSeq(row.toSeq :+ neighbors :+ distances)
+ } else {
+ Row.fromSeq(row.toSeq :+ neighbors)
+ }
+ },
+ transformSchema(dataset.schema)
+ )
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ val auxFeatures = $(inputCols).map(c => schema(c))
+ val schemaWithNeighbors = SchemaUtils.appendColumn(schema, $(neighborsCol), ArrayType(StructType(auxFeatures)))
+ if ($(distanceCol).isEmpty) {
+ schemaWithNeighbors
+ } else {
+ SchemaUtils.appendColumn(schemaWithNeighbors, $(distanceCol), ArrayType(DoubleType))
+ }
+ }
+
+ override def copy(extra: ParamMap): KNNModel = {
+ val copied = new KNNModel(uid, topTree, subTrees)
+ copyValues(copied, extra).setParent(parent)
+ }
+
+ def toNewClassificationModel(uid: String, numClasses: Int): KNNClassificationModel = {
+ copyValues(new KNNClassificationModel(uid, topTree, subTrees, numClasses))
+ }
+
+ def toNewRegressionModel(uid: String): KNNRegressionModel = {
+ copyValues(new KNNRegressionModel(uid, topTree, subTrees))
+ }
+}
+
+/**
+ * k-Nearest Neighbors (kNN) algorithm
+ *
+ * kNN finds k closest observations in training dataset. It can be used for both classification and regression.
+ * Furthermore it can also be used for other purposes such as input to clustering algorithm.
+ *
+ * While the brute-force approach requires no pre-training, each prediction requires going through the entire training
+ * set resulting O(n log(k)) runtime per individual prediction using a heap keep track of neighbor candidates.
+ * Many different implementations have been proposed such as Locality Sensitive Hashing (LSH), KD-Tree, Metric Tree and etc.
+ * Each algorithm has its shortcomings that prevent them to be effective on large-scale and/or high-dimensional dataset.
+ *
+ * This is an implementation of kNN based upon distributed Hybrid Spill-Trees where training points are organized into
+ * distributed binary trees. The algorithm is designed to support accurate approximate kNN search but by tuning parameters
+ * an exact search can also be performed with cost of additional runtime.
+ *
+ * Each binary tree node is either a
+ *
+ * '''Metric Node''':
+ * Metric Node partition points exclusively into two children by finding two pivot points and divide by middle plane.
+ * When searched, the child whose pivot is closer to query vector is searched first. Back tracking is required to
+ * ensure accuracy in this case, where the other child should be searched if it can possibly contain better neighbor
+ * based upon candidates picked during previous search.
+ *
+ * '''Spill Node''':
+ * Spill Node also partitions points into two children however there are an overlapping buffer between the two pivot
+ * points. The larger the buffer size, the less effective the node eliminates points thus could increase tree height.
+ * When searched, defeatist search is used where only one child is searched and no back tracking happens in this
+ * process. Because of the buffer between two children, we are likely to end up with good enough candidates without
+ * searching the other part of the tree.
+ *
+ * While Spill Node promises O(h) runtime where h is the tree height, the tree is deeper than Metric Tree's O(log n)
+ * height on average. Furthermore, when it comes down to leaves where points are more closer to each other, the static
+ * buffer size means more points will end up in the buffer. Therefore a Balance Threshold (rho) is introduced: when
+ * either child of Spill Node makes up more than rho fraction of the total points at this level, Spill Node is reverted
+ * back to a Metric Node.
+ *
+ * A high level overview of the algorithm is as follows:
+ *
+ * 1. Sample M data points (M is relatively small and can be held in driver)
+ * 1. Build the top level metric tree
+ * 1. Repartition RDD by assigning each point to leaf node of the above tree
+ * 1. Build a hybrid spill tree at each partition
+ *
+ * This concludes the training phase of kNN.
+ * See [[KNNModel]] for details on prediction phase.
+ *
+ *
+ * This algorithm is described in [[http://dx.doi.org/10.1109/WACV.2007.18]] where it was shown to scale well in terms of
+ * number of observations and dimensions, bounded by the available memory across clusters (billions in paper's example).
+ * This implementation adapts the MapReduce algorithm to work with Spark.
+ *
+ */
+class KNN(override val uid: String) extends Estimator[KNNModel] with KNNParams {
+ def this() = this(Identifiable.randomUID("knn"))
+
+ /** @group setParam */
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group setParam */
+ def setAuxCols(value: Array[String]): this.type = set(inputCols, value)
+
+ /** @group setParam */
+ def setTopTreeSize(value: Int): this.type = set(topTreeSize, value)
+
+ /** @group setParam */
+ def setTopTreeLeafSize(value: Int): this.type = set(topTreeLeafSize, value)
+
+ /** @group setParam */
+ def setSubTreeLeafSize(value: Int): this.type = set(subTreeLeafSize, value)
+
+ /** @group setParam */
+ def setBufferSizeSampleSizes(value: Array[Int]): this.type = set(bufferSizeSampleSizes, value)
+
+ /** @group setParam */
+ def setBalanceThreshold(value: Double): this.type = set(balanceThreshold, value)
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ override def fit(dataset: Dataset[_]): KNNModel = {
+ val rand = new XORShiftRandom($(seed))
+ //prepare data for model estimation
+ val data = dataset.selectExpr($(featuresCol), $(inputCols).mkString("struct(", ",", ")"))
+ .rdd
+ .map(row => new RowWithVector(row.getAs[Vector](0), row.getStruct(1)))
+ //sample data to build top-level tree
+ val sampled = data.sample(withReplacement = false, $(topTreeSize).toDouble / dataset.count(), rand.nextLong()).collect()
+ val topTree = MetricTree.build(sampled, $(topTreeLeafSize), rand.nextLong())
+ //build partitioner using top-level tree
+ val part = new KNNPartitioner(topTree)
+ //noinspection ScalaStyle
+ val repartitioned = new ShuffledRDD[RowWithVector, Null, Null](data.map(v => (v, null)), part).keys
+
+ val tau =
+ if ($(balanceThreshold) > 0 && $(bufferSize) < 0) {
+ KNN.estimateTau(data, $(bufferSizeSampleSizes), rand.nextLong())
+ } else {
+ math.max(0, $(bufferSize))
+ }
+ logInfo("Tau is: " + tau)
+
+ val trees = repartitioned.mapPartitionsWithIndex {
+ (partitionId, itr) =>
+ val rand = new XORShiftRandom(byteswap64($(seed) ^ partitionId))
+ val childTree =
+ HybridTree.build(itr.toIndexedSeq, $(subTreeLeafSize), tau, $(balanceThreshold), rand.nextLong())
+
+ Iterator(childTree)
+ }.persist(StorageLevel.MEMORY_AND_DISK)
+ // TODO: force persisting trees primarily for benchmark. any reason not to do this for regular runs?
+ trees.count()
+
+ val model = new KNNModel(uid, trees.context.broadcast(topTree), trees).setParent(this)
+ copyValues(model).setBufferSize(tau)
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ override def copy(extra: ParamMap): KNN = defaultCopy(extra)
+}
+
+
+object KNN {
+
+ val logger = log4j.Logger.getLogger(classOf[KNN])
+
+ /**
+ * VectorWithNorm can use more efficient algorithm to calculate distance
+ */
+ case class VectorWithNorm(vector: Vector, norm: Double) {
+ def this(vector: Vector) = this(vector, Vectors.norm(vector, 2))
+
+ def this(vector: BV[Double]) = this(Vectors.fromBreeze(vector))
+
+ def fastSquaredDistance(v: VectorWithNorm): Double = {
+ KNNUtils.fastSquaredDistance(vector, norm, v.vector, v.norm)
+ }
+
+ def fastDistance(v: VectorWithNorm): Double = math.sqrt(fastSquaredDistance(v))
+ }
+
+ /**
+ * VectorWithNorm plus auxiliary row information
+ */
+ case class RowWithVector(vector: VectorWithNorm, row: Row) {
+ def this(vector: Vector, row: Row) = this(new VectorWithNorm(vector), row)
+ }
+
+ /**
+ * Estimate a suitable buffer size based on dataset
+ *
+ * A suitable buffer size is the minimum size such that nearest neighbors can be accurately found even at
+ * boundary of splitting plane between pivot points. Therefore assuming points are uniformly distributed in
+ * high dimensional space, it should be approximately the average distance between points.
+ *
+ * Specifically the number of points within a certain radius of a given point is proportionally to the density of
+ * points raised to the effective number of dimensions, of which manifold data points exist on:
+ * R_s = \frac{c}{N_s ** 1/d}
+ * where R_s is the radius, N_s is the number of points, d is effective number of dimension, and c is a constant.
+ *
+ * To estimate R_s_all for entire dataset, we can take samples of the dataset of different size N_s to compute R_s.
+ * We can estimate c and d using linear regression. Lastly we can calculate R_s_all using total number of observation
+ * in dataset.
+ *
+ */
+ def estimateTau(data: RDD[RowWithVector], sampleSize: Array[Int], seed: Long): Double = {
+ val total = data.count()
+
+ // take samples of points for estimation
+ val samples = data.mapPartitionsWithIndex {
+ case (partitionId, itr) =>
+ val rand = new XORShiftRandom(byteswap64(seed ^ partitionId))
+ itr.flatMap {
+ p => sampleSize.zipWithIndex
+ .filter { case (size, _) => rand.nextDouble() * total < size }
+ .map { case (size, index) => (index, p) }
+ }
+ }
+ // compute N_s and R_s pairs
+ val estimators = samples
+ .groupByKey()
+ .map {
+ case (index, points) => (points.size, computeAverageDistance(points))
+ }.collect().distinct
+
+ // collect x and y vectors
+ val x = DenseVector(estimators.map { case (n, _) => math.log(n) })
+ val y = DenseVector(estimators.map { case (_, d) => math.log(d) })
+
+ // estimate log(R_s) = alpha + beta * log(N_s)
+ val xMeanVariance: MeanAndVariance = meanAndVariance(x)
+ val xmean = xMeanVariance.mean
+ val yMeanVariance: MeanAndVariance = meanAndVariance(y)
+ val ymean = yMeanVariance.mean
+
+ val corr = (mean(x *:* y) - xmean * ymean) / math.sqrt((mean(x *:* x) - xmean * xmean) * (mean(y *:* y) - ymean * ymean))
+
+ val beta = corr * yMeanVariance.stdDev / xMeanVariance.stdDev
+ val alpha = ymean - beta * xmean
+ val rs = math.exp(alpha + beta * math.log(total))
+
+ if (beta > 0 || beta.isNaN || rs.isNaN) {
+ val yMax = breeze.linalg.max(y)
+ logger.error(
+ s"""Unable to estimate Tau with positive beta: $beta. This maybe because data is too small.
+ |Setting to $yMax which is the maximum average distance we found in the sample.
+ |This may leads to poor accuracy. Consider manually set bufferSize instead.
+ |You can also try setting balanceThreshold to zero so only metric trees are built.""".stripMargin)
+ yMax
+ } else {
+ // c = alpha, d = - 1 / beta
+ rs / math.sqrt(-1 / beta)
+ }
+ }
+
+ // compute the average distance of nearest neighbors within points using brute-force
+ private[this] def computeAverageDistance(points: Iterable[RowWithVector]): Double = {
+ val distances = points.map {
+ point => points.map(p => p.vector.fastSquaredDistance(point.vector)).filter(_ > 0).min
+ }.map(math.sqrt)
+
+ distances.sum / distances.size
+ }
+
+ /**
+ * Search leaf index used by KNNPartitioner to partition training points
+ *
+ * @param v one training point to partition
+ * @param tree top tree constructed using sampled points
+ * @param acc accumulator used to help determining leaf index
+ * @return leaf/partition index
+ */
+ @tailrec
+ private[knn] def searchIndex(v: RowWithVector, tree: Tree, acc: Int = 0): Int = {
+ tree match {
+ case node: MetricTree =>
+ val leftDistance = node.leftPivot.fastSquaredDistance(v.vector)
+ val rightDistance = node.rightPivot.fastSquaredDistance(v.vector)
+ if (leftDistance < rightDistance) {
+ searchIndex(v, node.leftChild, acc)
+ } else {
+ searchIndex(v, node.rightChild, acc + node.leftChild.leafCount)
+ }
+ case _ => acc // reached leaf
+ }
+ }
+
+ //TODO: Might want to make this tail recursive
+ private[ml] def searchIndices(v: VectorWithNorm, tree: Tree, tau: Double, acc: Int = 0): Seq[Int] = {
+ tree match {
+ case node: MetricTree =>
+ val leftDistance = node.leftPivot.fastDistance(v)
+ val rightDistance = node.rightPivot.fastDistance(v)
+
+ val buffer = new ArrayBuffer[Int]
+ if (leftDistance - rightDistance <= tau) {
+ buffer ++= searchIndices(v, node.leftChild, tau, acc)
+ }
+
+ if (rightDistance - leftDistance <= tau) {
+ buffer ++= searchIndices(v, node.rightChild, tau, acc + node.leftChild.leafCount)
+ }
+
+ buffer
+ case _ => Seq(acc) // reached leaf
+ }
+ }
+
+ /**
+ * Partitioner used to map vector to leaf node which determines the partition it goes to
+ *
+ * @param tree `Tree` used to find leaf
+ */
+ class KNNPartitioner[T <: RowWithVector](tree: Tree) extends Partitioner {
+ override def numPartitions: Int = tree.leafCount
+
+ override def getPartition(key: Any): Int = {
+ key match {
+ case v: RowWithVector => searchIndex(v, tree)
+ case _ => throw new IllegalArgumentException(s"Key must be of type Vector but got: $key")
+ }
+ }
+
+ }
+
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/knn/MetricTree.scala b/ml-core/src/main/scala/org/apache/spark/ml/knn/MetricTree.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ae078c629fbbc8cd542f8de5577840d456d351a2
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/knn/MetricTree.scala
@@ -0,0 +1,398 @@
+//scalastyle:off
+package org.apache.spark.ml.knn
+
+import breeze.linalg._
+import org.apache.spark.ml.knn.KNN._
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.util.random.XORShiftRandom
+
+import scala.collection.mutable
+
+/**
+ * A [[Tree]] is used to store data points used in k-NN search. It represents
+ * a binary tree node. It keeps track of the pivot vector which closely approximate
+ * the center of all vectors within the node. All vectors are within the radius of
+ * distance to the pivot vector. Finally it knows the number of leaves to help
+ * determining partition index.
+ */
+private[ml] abstract class Tree extends Serializable {
+ val leftChild: Tree
+ val rightChild: Tree
+ val size: Int
+ val leafCount: Int
+ val pivot: VectorWithNorm
+ val radius: Double
+
+ def iterator: Iterator[RowWithVector]
+
+ /**
+ * k-NN query using pre-built [[Tree]]
+ * @param v vector to query
+ * @param k number of nearest neighbor
+ * @return a list of neighbor that is nearest to the query vector
+ */
+ def query(v: Vector, k: Int = 1): Iterable[(RowWithVector, Double)] = query(new VectorWithNorm(v), k)
+ def query(v: VectorWithNorm, k: Int): Iterable[(RowWithVector, Double)] = query(new KNNCandidates(v, k)).toIterable
+
+ /**
+ * Refine k-NN candidates using data in this [[Tree]]
+ */
+ private[knn] def query(candidates: KNNCandidates): KNNCandidates
+
+ /**
+ * Compute QueryCost defined as || v.center - q || - r
+ * when >= v.r node can be pruned
+ * for MetricNode this can be used to determine which child does queryVector falls into
+ */
+ private[knn] def distance(candidates: KNNCandidates): Double = distance(candidates.queryVector)
+
+ private[knn] def distance(v: VectorWithNorm): Double =
+ if(pivot.vector.size > 0) pivot.fastDistance(v) else 0.0
+}
+
+private[knn]
+case object Empty extends Tree {
+ override val leftChild = this
+ override val rightChild = this
+ override val size = 0
+ override val leafCount = 0
+ override val pivot = new VectorWithNorm(Vectors.dense(Array.empty[Double]))
+ override val radius = 0.0
+
+ override def iterator: Iterator[RowWithVector] = Iterator.empty
+ override def query(candidates: KNNCandidates): KNNCandidates = candidates
+}
+
+private[knn]
+case class Leaf (data: IndexedSeq[RowWithVector],
+ pivot: VectorWithNorm,
+ radius: Double) extends Tree {
+ override val leftChild = Empty
+ override val rightChild = Empty
+ override val size = data.size
+ override val leafCount = 1
+
+ override def iterator: Iterator[RowWithVector] = data.iterator
+
+ // brute force k-NN search at the leaf
+ override def query(candidates: KNNCandidates): KNNCandidates = {
+ val sorted = data
+ .map{ v => (v, candidates.queryVector.fastDistance(v.vector)) }
+ .sortBy(_._2)
+
+ for((v, d) <- sorted if candidates.notFull || d < candidates.maxDistance)
+ candidates.insert(v, d)
+
+ candidates
+ }
+}
+
+private[knn]
+object Leaf {
+ def apply(data: IndexedSeq[RowWithVector]): Leaf = {
+ val vectors = data.map(_.vector.vector.asBreeze)
+ val (minV, maxV) = vectors.foldLeft((vectors.head, vectors.head)) {
+ case ((accMin, accMax), bv) =>
+ (min(accMin, bv), max(accMax, bv))
+ }
+ val pivot = new VectorWithNorm((minV + maxV) / 2.0)
+ val radius = math.sqrt(squaredDistance(minV, maxV)) / 2.0
+ Leaf(data, pivot, radius)
+ }
+}
+
+/**
+ * A [[MetricTree]] represents a MetricNode where data are split into two partitions: left and right.
+ * There exists two pivot vectors: leftPivot and rightPivot to determine the partitioning.
+ * Pivot vector should be the middle of leftPivot and rightPivot vectors.
+ * Points that is closer to leftPivot than to rightPivot belongs to leftChild and rightChild otherwise.
+ *
+ * During search, because we have information about each child's pivot and radius, we can see if the
+ * hyper-sphere intersects with current candidates sphere. If so, we search the child that has the
+ * most potential (i.e. the child which has the closest pivot).
+ * Once that child has been fully searched, we backtrack to the remaining child and search if necessary.
+ *
+ * This is much more efficient than naive brute force search. However backtracking can take a lot of time
+ * when the number of dimension is high (due to longer time to compute distance and the volume growing much
+ * faster than radius).
+ */
+private[knn]
+case class MetricTree(leftChild: Tree,
+ leftPivot: VectorWithNorm,
+ rightChild: Tree,
+ rightPivot: VectorWithNorm,
+ pivot: VectorWithNorm,
+ radius: Double
+ ) extends Tree {
+ override val size = leftChild.size + rightChild.size
+ override val leafCount = leftChild.leafCount + rightChild.leafCount
+
+ override def iterator: Iterator[RowWithVector] = leftChild.iterator ++ rightChild.iterator
+ override def query(candidates: KNNCandidates): KNNCandidates = {
+ lazy val leftQueryCost = leftChild.distance(candidates)
+ lazy val rightQueryCost = rightChild.distance(candidates)
+ // only query if at least one of the children is worth looking
+ if(candidates.notFull ||
+ leftQueryCost - candidates.maxDistance < leftChild.radius ||
+ rightQueryCost - candidates.maxDistance < rightChild.radius ){
+ val remainingChild = {
+ if (leftQueryCost <= rightQueryCost) {
+ leftChild.query(candidates)
+ rightChild
+ } else {
+ rightChild.query(candidates)
+ leftChild
+ }
+ }
+ // check again to see if the remaining child is still worth looking
+ if (candidates.notFull ||
+ remainingChild.distance(candidates) - candidates.maxDistance < remainingChild.radius) {
+ remainingChild.query(candidates)
+ }
+ }
+ candidates
+ }
+}
+
+object MetricTree {
+ /**
+ * Build a (metric)[[Tree]] that facilitate k-NN query
+ *
+ * @param data vectors that contain all training data
+ * @param seed random number generator seed used in pivot point selecting
+ * @return a [[Tree]] can be used to do k-NN query
+ */
+ def build(data: IndexedSeq[RowWithVector], leafSize: Int = 1, seed: Long = 0L): Tree = {
+ val size = data.size
+ if(size == 0) {
+ Empty
+ } else if(size <= leafSize) {
+ Leaf(data)
+ } else {
+ val rand = new XORShiftRandom(seed)
+ val randomPivot = data(rand.nextInt(size)).vector
+ val leftPivot = data.maxBy(v => randomPivot.fastSquaredDistance(v.vector)).vector
+ if(leftPivot == randomPivot) {
+ // all points are identical (or only one point left)
+ Leaf(data, randomPivot, 0.0)
+ } else {
+ val rightPivot = data.maxBy(v => leftPivot.fastSquaredDistance(v.vector)).vector
+ val pivot = new VectorWithNorm(Vectors.fromBreeze((leftPivot.vector.asBreeze + rightPivot.vector.asBreeze) / 2.0))
+ val radius = math.sqrt(data.map(v => pivot.fastSquaredDistance(v.vector)).max)
+ val (leftPartition, rightPartition) = data.partition{
+ v => leftPivot.fastSquaredDistance(v.vector) < rightPivot.fastSquaredDistance(v.vector)
+ }
+
+ MetricTree(
+ build(leftPartition, leafSize, rand.nextLong()),
+ leftPivot,
+ build(rightPartition, leafSize, rand.nextLong()),
+ rightPivot,
+ pivot,
+ radius
+ )
+ }
+ }
+ }
+}
+
+/**
+ * A [[SpillTree]] represents a SpillNode. Just like [[MetricTree]], it splits data into two partitions.
+ * However, instead of partition data into exactly two halves, it contains a buffer zone with size of tau.
+ * Left child contains all data left to the center plane + tau (in the leftPivot -> rightPivot direction).
+ * Right child contains all data right to the center plane - tau.
+ *
+ * Search doesn't do backtracking but rather adopt a defeatist search where it search the most prominent
+ * child and that child only. The buffer ensures such strategy doesn't result in a poor outcome.
+ */
+private[knn]
+case class SpillTree(leftChild: Tree,
+ leftPivot: VectorWithNorm,
+ rightChild: Tree,
+ rightPivot: VectorWithNorm,
+ pivot: VectorWithNorm,
+ radius: Double,
+ tau: Double,
+ bufferSize: Int
+ ) extends Tree {
+ override val size = leftChild.size + rightChild.size - bufferSize
+ override val leafCount = leftChild.leafCount + rightChild.leafCount
+
+ override def iterator: Iterator[RowWithVector] =
+ leftChild.iterator ++ rightChild.iterator.filter(childFilter(leftPivot, rightPivot))
+
+ override def query(candidates: KNNCandidates): KNNCandidates = {
+ if (size <= candidates.k - candidates.candidates.size) {
+ iterator.foreach(candidates.insert)
+ } else {
+ val leftQueryCost = candidates.queryVector.fastSquaredDistance(leftPivot)
+ val rightQueryCost = candidates.queryVector.fastSquaredDistance(rightPivot)
+
+ (if (leftQueryCost <= rightQueryCost) leftChild else rightChild).query(candidates)
+
+ // fill candidates with points from other child excluding buffer so we don't double count.
+ // depending on K and how high we are in the tree, this can be very expensive and undesirable
+ // TODO: revisit this idea when we do large scale testing
+ if(candidates.notFull) {
+ (if (leftQueryCost <= rightQueryCost) {
+ rightChild.iterator.filter(childFilter(leftPivot, rightPivot))
+ } else {
+ leftChild.iterator.filter(childFilter(rightPivot, leftPivot))
+ }).foreach(candidates.tryInsert)
+ }
+ }
+ candidates
+ }
+
+ private[this] val childFilter: (VectorWithNorm, VectorWithNorm) => RowWithVector => Boolean =
+ (p1, p2) => p => p.vector.fastDistance(p1) - p.vector.fastDistance(p2) > tau
+}
+
+
+object SpillTree {
+ /**
+ * Build a (spill)[[Tree]] that facilitate k-NN query
+ *
+ * @param data vectors that contain all training data
+ * @param tau overlapping size
+ * @param seed random number generators seed used in pivot point selecting
+ * @return a [[Tree]] can be used to do k-NN query
+ */
+ def build(data: IndexedSeq[RowWithVector], leafSize: Int = 1, tau: Double, seed: Long = 0L): Tree = {
+ val size = data.size
+ if (size == 0) {
+ Empty
+ } else if (size <= leafSize) {
+ Leaf(data)
+ } else {
+ val rand = new XORShiftRandom(seed)
+ val randomPivot = data(rand.nextInt(size)).vector
+ val leftPivot = data.maxBy(v => randomPivot.fastSquaredDistance(v.vector)).vector
+ if (leftPivot == randomPivot) {
+ // all points are identical (or only one point left)
+ Leaf(data, randomPivot, 0.0)
+ } else {
+ val rightPivot = data.maxBy(v => leftPivot.fastSquaredDistance(v.vector)).vector
+ val pivot = new VectorWithNorm(Vectors.fromBreeze((leftPivot.vector.asBreeze + rightPivot.vector.asBreeze) / 2.0))
+ val radius = math.sqrt(data.map(v => pivot.fastSquaredDistance(v.vector)).max)
+ val dataWithDistance = data.map(v =>
+ (v, leftPivot.fastDistance(v.vector), rightPivot.fastDistance(v.vector))
+ )
+ val leftPartition = dataWithDistance.filter { case (_, left, right) => left - right <= tau }.map(_._1)
+ val rightPartition = dataWithDistance.filter { case (_, left, right) => right - left <= tau }.map(_._1)
+
+ SpillTree(
+ build(leftPartition, leafSize, tau, rand.nextLong()),
+ leftPivot,
+ build(rightPartition, leafSize, tau, rand.nextLong()),
+ rightPivot,
+ pivot,
+ radius,
+ tau,
+ leftPartition.size + rightPartition.size - size
+ )
+ }
+ }
+ }
+}
+
+object HybridTree {
+ /**
+ * Build a (hybrid-spill) `Tree` that facilitate k-NN query
+ *
+ * @param data vectors that contain all training data
+ * @param seed random number generator seed used in pivot point selecting
+ * @param tau overlapping size
+ * @param rho balance threshold
+ * @return a `Tree` can be used to do k-NN query
+ */
+ //noinspection ScalaStyle
+ def build(data: IndexedSeq[RowWithVector],
+ leafSize: Int = 1,
+ tau: Double,
+ rho: Double = 0.7,
+ seed: Long = 0L): Tree = {
+ val size = data.size
+ if (size == 0) {
+ Empty
+ } else if (size <= leafSize) {
+ Leaf(data)
+ } else {
+ val rand = new XORShiftRandom(seed)
+ val randomPivot = data(rand.nextInt(size)).vector
+ val leftPivot = data.maxBy(v => randomPivot.fastSquaredDistance(v.vector)).vector
+ if (leftPivot == randomPivot) {
+ // all points are identical (or only one point left)
+ Leaf(data, randomPivot, 0.0)
+ } else {
+ val rightPivot = data.maxBy(v => leftPivot.fastSquaredDistance(v.vector)).vector
+ val pivot = new VectorWithNorm(Vectors.fromBreeze((leftPivot.vector.asBreeze + rightPivot.vector.asBreeze) / 2.0))
+ val radius = math.sqrt(data.map(v => pivot.fastSquaredDistance(v.vector)).max)
+ lazy val dataWithDistance = data.map(v =>
+ (v, leftPivot.fastDistance(v.vector), rightPivot.fastDistance(v.vector))
+ )
+ // implemented boundary is parabola (rather than perpendicular plane described in the paper)
+ lazy val leftPartition = dataWithDistance.filter { case (_, left, right) => left - right <= tau }.map(_._1)
+ lazy val rightPartition = dataWithDistance.filter { case (_, left, right) => right - left <= tau }.map(_._1)
+
+ if(rho <= 0.0 || leftPartition.size > size * rho || rightPartition.size > size * rho) {
+ //revert back to metric node
+ val (leftPartition, rightPartition) = data.partition{
+ v => leftPivot.fastSquaredDistance(v.vector) < rightPivot.fastSquaredDistance(v.vector)
+ }
+ MetricTree(
+ build(leftPartition, leafSize, tau, rho, rand.nextLong()),
+ leftPivot,
+ build(rightPartition, leafSize, tau, rho, rand.nextLong()),
+ rightPivot,
+ pivot,
+ radius
+ )
+ } else {
+ SpillTree(
+ build(leftPartition, leafSize, tau, rho, rand.nextLong()),
+ leftPivot,
+ build(rightPartition, leafSize, tau, rho, rand.nextLong()),
+ rightPivot,
+ pivot,
+ radius,
+ tau,
+ leftPartition.size + rightPartition.size - size
+ )
+ }
+ }
+ }
+ }
+}
+
+/**
+ * Structure to maintain search progress/results for a single query vector.
+ * Internally uses a PriorityQueue to maintain a max-heap to keep track of the
+ * next neighbor to evict.
+ *
+ * @param queryVector vector being searched
+ * @param k number of neighbors to return
+ */
+private[knn]
+class KNNCandidates(val queryVector: VectorWithNorm, val k: Int) extends Serializable {
+ private[knn] val candidates = mutable.PriorityQueue.empty[(RowWithVector, Double)] {
+ Ordering.by(_._2)
+ }
+
+ // return the current maximum distance from neighbor to search vector
+ def maxDistance: Double = if(candidates.isEmpty) 0.0 else candidates.head._2
+ // insert evict neighbor if required. however it doesn't make sure the insert improves
+ // search results. it is caller's responsibility to make sure either candidate list
+ // is not full or the inserted neighbor brings the maxDistance down
+ def insert(v: RowWithVector, d: Double): Unit = {
+ while(candidates.size >= k) candidates.dequeue()
+ candidates.enqueue((v, d))
+ }
+ def insert(v: RowWithVector): Unit = insert(v, v.vector.fastDistance(queryVector))
+ def tryInsert(v: RowWithVector): Unit = {
+ val distance = v.vector.fastDistance(queryVector)
+ if(notFull || distance < maxDistance) insert(v, distance)
+ }
+ def toIterable: Iterable[(RowWithVector, Double)] = candidates
+ def notFull: Boolean = candidates.size < k
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/recommendation/NMFLocalIndexEncoder.scala b/ml-core/src/main/scala/org/apache/spark/ml/recommendation/NMFLocalIndexEncoder.scala
new file mode 100644
index 0000000000000000000000000000000000000000..f4aeb908b096d6f5453a99f31621318ca60b58c3
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/recommendation/NMFLocalIndexEncoder.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+/**
+ * Encoder for storing (blockId, localIndex) into a single integer.
+ *
+ * We use the leading bits (including the sign bit) to store the block id and the rest to store
+ * the local index. This is based on the assumption that users/items are approximately evenly
+ * partitioned. With this assumption, we should be able to encode two billion distinct values.
+ *
+ * @param numBlocks number of blocks
+ */
+private[recommendation] class NMFLocalIndexEncoder(numBlocks: Int) extends Serializable {
+
+ require(numBlocks > 0, s"numBlocks must be positive but found $numBlocks.")
+
+ private[this] final val numLocalIndexBits =
+ math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31)
+ private[this] final val localIndexMask = (1 << numLocalIndexBits) - 1
+
+ /** Encodes a (blockId, localIndex) into a single integer. */
+ def encode(blockId: Int, localIndex: Int): Int = {
+ require(blockId < numBlocks)
+ require((localIndex & ~localIndexMask) == 0)
+ (blockId << numLocalIndexBits) | localIndex
+ }
+
+ /** Gets the block id from an encoded index. */
+ @inline
+ def blockId(encoded: Int): Int = {
+ encoded >>> numLocalIndexBits
+ }
+
+ /** Gets the local index from an encoded index. */
+ @inline
+ def localIndex(encoded: Int): Int = {
+ encoded & localIndexMask
+ }
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/regression/KNNRegression.scala b/ml-core/src/main/scala/org/apache/spark/ml/regression/KNNRegression.scala
new file mode 100644
index 0000000000000000000000000000000000000000..29ca5b1de402ffb640a2990112d23022c8b4151d
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/regression/KNNRegression.scala
@@ -0,0 +1,157 @@
+//scalastyle:off
+package org.apache.spark.ml.regression
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.knn.{KNN, KNNModelParams, KNNParams, Tree}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.{PredictionModel, Predictor}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * [[https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm]] for regression.
+ * The output value is simply the average of the values of its k nearest neighbors.
+ */
+class KNNRegression(override val uid: String) extends Predictor[Vector, KNNRegression, KNNRegressionModel]
+ with KNNParams with HasWeightCol {
+ def this() = this(Identifiable.randomUID("knnr"))
+
+ /** @group setParam */
+ override def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ override def setLabelCol(value: String): this.type = {
+ set(labelCol, value)
+
+ if ($(weightCol).isEmpty) {
+ set(inputCols, Array(value))
+ } else {
+ set(inputCols, Array(value, $(weightCol)))
+ }
+ }
+
+ //fill in default label col
+ setDefault(inputCols, Array($(labelCol)))
+
+ /** @group setWeight */
+ def setWeightCol(value: String): this.type = {
+ set(weightCol, value)
+
+ if (value.isEmpty) {
+ set(inputCols, Array($(labelCol)))
+ } else {
+ set(inputCols, Array($(labelCol), value))
+ }
+ }
+
+ setDefault(weightCol -> "")
+
+ /** @group setParam */
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group setParam */
+ def setTopTreeSize(value: Int): this.type = set(topTreeSize, value)
+
+ /** @group setParam */
+ def setTopTreeLeafSize(value: Int): this.type = set(topTreeLeafSize, value)
+
+ /** @group setParam */
+ def setSubTreeLeafSize(value: Int): this.type = set(subTreeLeafSize, value)
+
+ /** @group setParam */
+ def setBufferSizeSampleSizes(value: Array[Int]): this.type = set(bufferSizeSampleSizes, value)
+
+ /** @group setParam */
+ def setBalanceThreshold(value: Double): this.type = set(balanceThreshold, value)
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ override protected def train(dataset: Dataset[_]): KNNRegressionModel = {
+ val knnModel = copyValues(new KNN()).fit(dataset)
+ knnModel.toNewRegressionModel(uid)
+ }
+
+ override def fit(dataset: Dataset[_]): KNNRegressionModel = {
+ // Need to overwrite this method because we need to manually overwrite the buffer size
+ // because it is not supposed to stay the same as the Regressor if user sets it to -1.
+ transformSchema(dataset.schema, logging = true)
+ val model = train(dataset)
+ val bufferSize = model.getBufferSize
+ copyValues(model.setParent(this)).setBufferSize(bufferSize)
+ }
+
+ override def copy(extra: ParamMap): KNNRegression = defaultCopy(extra)
+}
+
+class KNNRegressionModel private[ml](
+ override val uid: String,
+ val topTree: Broadcast[Tree],
+ val subTrees: RDD[Tree]
+) extends PredictionModel[Vector, KNNRegressionModel]
+ with KNNModelParams with HasWeightCol with Serializable {
+ require(subTrees.getStorageLevel != StorageLevel.NONE,
+ "KNNModel is not designed to work with Trees that have not been cached")
+
+ /** @group setParam */
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group setParam */
+ def setBufferSize(value: Double): this.type = set(bufferSize, value)
+
+ //TODO: This can benefit from DataSet API in Spark 1.6
+ override def transformImpl(dataset: Dataset[_]): DataFrame = {
+ val getWeight: Row => Double = {
+ if($(weightCol).isEmpty) {
+ r => 1.0
+ } else {
+ r => r.getDouble(1)
+ }
+ }
+
+ val neighborDataset : RDD[(Long, Array[(Row, Double)])] = transform(dataset, topTree, subTrees)
+ val merged = neighborDataset
+ .map {
+ case (id, labelsDists) =>
+ val (labels, _) = labelsDists.unzip
+ var i = 0
+ var weight = 0.0
+ var sum = 0.0
+ val length = labels.length
+ while (i < length) {
+ val row = labels(i)
+ val w = getWeight(row)
+ sum += row.getDouble(0) * w
+ weight += w
+ i += 1
+ }
+
+ (id, sum / weight)
+ }
+
+ dataset.sqlContext.createDataFrame(
+ dataset.toDF().rdd.zipWithIndex().map { case (row, i) => (i, row) }
+ .leftOuterJoin(merged) //make sure we don't lose any observations
+ .map {
+ case (i, (row, value)) => Row.fromSeq(row.toSeq :+ value.get)
+ },
+ transformSchema(dataset.schema)
+ )
+ }
+
+ override def copy(extra: ParamMap): KNNRegressionModel = {
+ val copied = new KNNRegressionModel(uid, topTree, subTrees)
+ copyValues(copied, extra).setParent(parent)
+ }
+
+ override def predict(features: Vector): Double = {
+ val neighborDataset : RDD[(Long, Array[(Row, Double)])] = transform(subTrees.context.parallelize(Seq(features)), topTree, subTrees)
+ val results = neighborDataset.first()._2
+ val labels = results.map(_._1.getDouble(0))
+ labels.sum / labels.length
+ }
+}
\ No newline at end of file
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesCore.scala b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesCore.scala
index d60279794a340e6146726d4744f76443678c4e2d..cf5680d0bcbebb7ae06dcc1d70a049a9500ddd96 100644
--- a/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesCore.scala
+++ b/ml-core/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesCore.scala
@@ -80,8 +80,8 @@ object GradientBoostedTreesCore extends Logging{
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
- val leftWeight = leftCount / totalCount.toDouble
- val rightWeight = rightCount / totalCount.toDouble
+ val leftWeight = leftCount / totalCount
+ val rightWeight = rightCount / totalCount
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tuning/BaseRange.scala b/ml-core/src/main/scala/org/apache/spark/ml/tuning/BaseRange.scala
new file mode 100644
index 0000000000000000000000000000000000000000..76ef96404bbb2f80bfea9136b0d892a03372e8f9
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/tuning/BaseRange.scala
@@ -0,0 +1,77 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.spark.ml.tuning
+
+import scala.util.Random
+
+/**
+ * Abstract base class for IntervalRange, ContinueRange and DiscreteRange.
+ */
+abstract class BaseRange() {
+ val rd = new Random()
+
+ /**
+ * Sample a value from a range of values.
+ * @return a value.
+ */
+ def sampleOne(): Double
+}
+
+/** Create a new range with the `start`, `end` and `step` values of this range.
+ *
+ * @param start the start value
+ * @param end the end value
+ * @param step the step value
+ */
+final case class IntervalRange(start: Double, end: Double, step: Double) extends BaseRange {
+ require(end != start, s"Upper boundary $end must not be equal to boundary $start")
+ require(step != 0.0, s"Step must not be equal to 0")
+
+ val paramValues: List[Double] = (BigDecimal(start) to end by step).map(_.toDouble).toList
+
+ /**
+ * Sample a value from a range of values.
+ * @return a value.
+ */
+ override def sampleOne(): Double = {
+ paramValues(rd.nextInt(paramValues.length))
+ }
+}
+
+/** Create a new range with the `lower` and `upper` values of this range.
+ *
+ * @param lower the start value
+ * @param upper the end value
+ */
+final case class ContinueRange(lower: Double, upper: Double) extends BaseRange {
+ require(upper > lower, s"Upper boundary $upper must be greater than lower boundary $lower")
+
+ /**
+ * sample a value from a range of values.
+ * @return a value.
+ */
+ override def sampleOne(): Double = {
+ lower + (upper - lower) * rd.nextDouble()
+ }
+}
+
+/** Create a new range with the discrete values set.
+ *
+ * @param paramValues set of discrete values.
+ */
+final case class DiscreteRange(paramValues: Seq[Double]) extends BaseRange {
+
+ /**
+ * sample a value from a range of values.
+ * @return a value.
+ */
+ override def sampleOne(): Double = {
+ paramValues(rd.nextInt(paramValues.length))
+ }
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tuning/ParamSpace.scala b/ml-core/src/main/scala/org/apache/spark/ml/tuning/ParamSpace.scala
new file mode 100644
index 0000000000000000000000000000000000000000..83c511b59d10f0b4008e665bcdf6fc06bb2c6683
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/tuning/ParamSpace.scala
@@ -0,0 +1,100 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.spark.ml.tuning
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.{Param, ParamMap}
+
+/**
+ * Class of hyper-parameters space.
+ */
+class ParamSpace() {
+ var paramList: List[ParamType[_ <: AnyVal]] = List()
+ var paramNames: Array[Param[_ <: AnyVal]] = Array()
+
+ /**
+ * Add IntType hyper-parameters.
+ *
+ * @param parent parent object.
+ * @param name param name.
+ * @param valueRange Range of parameter values.
+ */
+ def addIntParam(parent: String, name: String, valueRange: BaseRange): Unit = {
+ val param = IntParamType(valueRange, parent, name)
+ paramList :+= param
+ paramNames :+= param.getParamName
+ }
+
+ /**
+ * Add DoubleType hyper-parameters.
+ *
+ * @param parent parent object.
+ * @param name param name.
+ * @param valueRange Range of parameter values.
+ */
+ def addDoubleParam(parent: String, name: String, valueRange: BaseRange): Unit = {
+ val param = DoubleParmType(valueRange, parent, name)
+ paramList :+= param
+ paramNames :+= param.getParamName
+ }
+
+ private def asDouble(num: Any): Double = {
+ num match {
+ case i: Int => i.toDouble
+ case i: Long => i.toDouble
+ case i: Float => i.toDouble
+ case i: Double => i
+ case _ => throw new Exception(s"type ${num.getClass} is not supported")
+ }
+ }
+
+ /**
+ * Get configuration values from paramMaps.
+ *
+ * @param configs param configurations.
+ * @return param value.
+ */
+ def getConfigsValue(configs: Array[ParamMap]): Array[Vector] = {
+ val values: ArrayBuffer[Vector] = new ArrayBuffer[Vector]
+ for {config <- configs} {
+ var vectorArray: Array[Double] = Array()
+ paramNames.foreach { paramNames =>
+ vectorArray :+= asDouble(config(paramNames))
+ }
+ values.append(Vectors.dense(vectorArray))
+ }
+ values.toArray
+ }
+
+ /**
+ * Get some recommended configurations.
+ *
+ * @param size configuration number.
+ * @return configurations and configuration value vectors.
+ */
+ def getConfigurations(size: Int): (Array[ParamMap], Array[Vector]) = {
+ val configs: ArrayBuffer[ParamMap] = new ArrayBuffer[ParamMap]
+ val values: ArrayBuffer[Vector] = new ArrayBuffer[Vector]
+ for {iter <- 1 to size} {
+ val paramMap = ParamMap.empty
+ var vectorArray: Array[Double] = Array()
+ paramList.foreach(param => {
+ val x = param.giveParamPair()
+ paramMap.put(x)
+ vectorArray :+= asDouble(x.value)
+
+ })
+ configs.append(paramMap)
+ values.append(Vectors.dense(vectorArray))
+ }
+ (configs.toArray, values.toArray)
+ }
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/ml/tuning/ParamType.scala b/ml-core/src/main/scala/org/apache/spark/ml/tuning/ParamType.scala
new file mode 100644
index 0000000000000000000000000000000000000000..838d7ea7f75cfed9ab268a04c3b64e9ae96c47c1
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/ml/tuning/ParamType.scala
@@ -0,0 +1,94 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.ml.param.{Param, ParamPair}
+
+/**
+ * Abstract base class for IntParamType and DoubleParmType.
+ */
+abstract class ParamType[T <: AnyVal] {
+ /**
+ * Sample one param from valueRange.
+ * @return a param value.
+ */
+ def sampleOne(): T
+
+ /**
+ * get param name.
+ * @return param name.
+ */
+ def getParamName: Param[T]
+
+ /**
+ * get param name and sample one param.
+ * @return a param and its value.
+ */
+ def giveParamPair(): ParamPair[T]
+}
+
+/**
+ * Param for int type values.
+ *
+ * @param valueRange range of param values.
+ * @param parent parent object.
+ * @param name param name.
+ */
+final case class IntParamType(valueRange: BaseRange, parent: String, name: String)
+ extends ParamType[Int] {
+ val paramName: Param[Int] = new Param(parent, name, "")
+
+ /**
+ * Sample one param from valueRange.
+ * @return a param value.
+ */
+ override def sampleOne(): Int = valueRange.sampleOne().toInt
+
+ /**
+ * get param name.
+ * @return param name.
+ */
+ override def getParamName: Param[Int] = paramName
+
+ /**
+ * get param name and sample one param.
+ * @return a param and its value.
+ */
+ override def giveParamPair(): ParamPair[Int] = ParamPair(getParamName, sampleOne())
+}
+
+/**
+ * Param for Double type values.
+ *
+ * @param valueRange range of param values.
+ * @param parent parent object.
+ * @param name param name.
+ */
+final case class DoubleParmType(valueRange: BaseRange, parent: String, name: String)
+ extends ParamType[Double] {
+ val paramName: Param[Double] = new Param(parent, name, "")
+
+ /**
+ * Sample one param from valueRange.
+ * @return a param value.
+ */
+ override def sampleOne(): Double = valueRange.sampleOne()
+
+ /**
+ * get param name and sample one param.
+ * @return a param and its value.
+ */
+ override def getParamName: Param[Double] = paramName
+
+ /**
+ * get param name and sample one param.
+ * @return a param and its value.
+ */
+ override def giveParamPair(): ParamPair[Double] = ParamPair(paramName, sampleOne())
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthCore.scala b/ml-core/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthCore.scala
new file mode 100644
index 0000000000000000000000000000000000000000..474a8d1b37dbf1e019b81ffcd16c77f14496a51c
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthCore.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm
+
+import java.{util => ju}
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partitioner, SparkException}
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.fpm.FPGrowth._
+import org.apache.spark.rdd.RDD
+
+@Since("1.3.0")
+object FPGrowthCore extends Logging with Serializable {
+
+ /**
+ * Generates frequent items by filtering the input data using minimal support level.
+ * @param minCount minimum count for frequent itemsets
+ * @param partitioner partitioner used to distribute items
+ * @return array of frequent patterns and their frequencies ordered by their frequencies
+ */
+ private[fpm] def genFreqItems[Item: ClassTag](
+ data: RDD[Array[Item]],
+ minCount: Long,
+ partitioner: Partitioner): Array[(Item, Long)] = {
+ data.flatMap { t =>
+ val uniq = t.toSet
+ if (t.length != uniq.size) {
+ throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
+ }
+ t
+ }.map(v => (v, 1L))
+ .reduceByKey(partitioner, _ + _)
+ .filter(_._2 >= minCount)
+ .collect()
+ .sortBy(-_._2)
+ }
+
+ /**
+ * Generate frequent itemsets by building FP-Trees, the extraction is done on each partition.
+ * @param data transactions
+ * @param minCount minimum count for frequent itemsets
+ * @param freqItems frequent items
+ * @param partitioner partitioner used to distribute transactions
+ * @return an RDD of (frequent itemset, count)
+ */
+ private[fpm] def genFreqItemsets[Item: ClassTag](
+ data: RDD[Array[Item]],
+ minCount: Long,
+ freqItems: Array[Item],
+ partitioner: Partitioner): RDD[FreqItemset[Item]] = {
+ val itemToRank = freqItems.zipWithIndex.toMap
+ data.flatMap { transaction =>
+ genCondTransactions(transaction, itemToRank, partitioner)
+ }.aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
+ (tree, transaction) => tree.add(transaction, 1L),
+ (tree1, tree2) => tree1.merge(tree2))
+ .flatMap { case (part, tree) =>
+ tree.extract(minCount, x => partitioner.getPartition(x) == part)
+ }.map { case (ranks, count) =>
+ new FreqItemset(ranks.map(i => freqItems(i)).toArray, count)
+ }
+ }
+
+ /**
+ * Generates conditional transactions.
+ * @param transaction a transaction
+ * @param itemToRank map from item to their rank
+ * @param partitioner partitioner used to distribute transactions
+ * @return a map of (target partition, conditional transaction)
+ */
+ private[fpm] def genCondTransactions[Item: ClassTag](
+ transaction: Array[Item],
+ itemToRank: Map[Item, Int],
+ partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
+ val output = mutable.Map.empty[Int, Array[Int]]
+ // Filter the basket by frequent items pattern and sort their ranks.
+ val filtered = transaction.flatMap(itemToRank.get)
+ ju.Arrays.sort(filtered)
+ val n = filtered.length
+ var i = n - 1
+ while (i >= 0) {
+ val item = filtered(i)
+ val part = partitioner.getPartition(item)
+ if (!output.contains(part)) {
+ output(part) = filtered.slice(0, i + 1)
+ }
+ i -= 1
+ }
+ output
+ }
+}
diff --git a/ml-core/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/ml-core/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 789c2998f9881b38a3e0805f04e45094ce2aa562..f9be017c2c6c8e738202bc41f33d15977b76971d 100644
--- a/ml-core/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/ml-core/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -67,7 +67,7 @@ private[fpm] class LocalPrefixSpan(
count >= minCount
}.sorted
// project and recursively call genFreqPatterns
- freqItems.toIterator.flatMap { case (item, count) =>
+ freqItems.iterator.flatMap { case (item, count) =>
val newPrefix = prefix :+ item
val (stat, continue) =
LocalPrefixSpanUtils.getLocalPrefixStat(newPrefix, count, maxPatternLength, target)
diff --git a/ml-core/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpanBase.scala b/ml-core/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpanBase.scala
index 313dfa5841e9337dfd3dcd2b3deb1946b9284c32..27317a670f7cd1d24c7ac8e37074c54fbc4a7827 100644
--- a/ml-core/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpanBase.scala
+++ b/ml-core/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpanBase.scala
@@ -137,7 +137,7 @@ import scala.collection.mutable
}
i += 1
}
- prefixes.toIterator
+ prefixes.iterator
}
/**
@@ -180,7 +180,7 @@ import scala.collection.mutable
}
i += 1
}
- prefixes.toIterator
+ prefixes.iterator
}
/** Tests whether this postfix is non-empty. */
diff --git a/ml-core/src/main/scala/org/apache/spark/mllib/knn/KNNUtils.scala b/ml-core/src/main/scala/org/apache/spark/mllib/knn/KNNUtils.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e6eba84979eb2cd29b5c53104776cf3d925b0c60
--- /dev/null
+++ b/ml-core/src/main/scala/org/apache/spark/mllib/knn/KNNUtils.scala
@@ -0,0 +1,21 @@
+//scalastyle:off
+package org.apache.spark.mllib.knn
+
+import org.apache.spark.ml.{linalg => newlinalg}
+import org.apache.spark.mllib.{linalg => oldlinalg}
+import org.apache.spark.mllib.util.MLUtils
+
+object KNNUtils {
+
+ import oldlinalg.VectorImplicits._
+
+ def fastSquaredDistance(
+ v1: newlinalg.Vector,
+ norm1: Double,
+ v2: newlinalg.Vector,
+ norm2: Double,
+ precision: Double = 1e-6): Double = {
+ MLUtils.fastSquaredDistance(v1, norm1, v2, norm2, precision)
+ }
+
+}
diff --git a/ml-kernel-client-core/pom.xml b/ml-kernel-client-core/pom.xml
index a173de6c370e6e2f8ddb5639d26cc9ffac3595a2..4360b23d6fe6eacfbd66dde139613f6288a10704 100644
--- a/ml-kernel-client-core/pom.xml
+++ b/ml-kernel-client-core/pom.xml
@@ -2,12 +2,12 @@
org.apache.spark
boostkit-ml
- 2.2.0
+ 3.0.0
4.0.0
boostkit-ml-kernel-client-core_2.12
- 2.2.0
+ 3.0.0
${project.artifactId}
Spark ml core
@@ -16,6 +16,7 @@
org.apache.hbase
hbase-common
1.3.1
+ provided
hadoop-common
@@ -70,6 +71,9 @@
+ -unchecked
+ -deprecation
+ -feature
-dependencyfile
${project.build.directory}/.scala_dependencies
diff --git a/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/DataParser.java b/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/DataParser.java
index 851e1ccdacfcc211d1ac1cfa7adff8a7ddbd4d5b..150f55a0261d7acb176ad442d72282cf4972e0a4 100644
--- a/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/DataParser.java
+++ b/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/DataParser.java
@@ -1,3 +1,10 @@
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
package org.apache.spark.ml.feature.df.intf;
import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
diff --git a/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/LanguagesConf.java b/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/LanguagesConf.java
index 49eed137a0bfe9733a6b03a5c78188817d148ead..bf602fb80ab09b550abc54b8fe58b5f34efca49d 100644
--- a/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/LanguagesConf.java
+++ b/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/LanguagesConf.java
@@ -1,3 +1,10 @@
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
package org.apache.spark.ml.feature.df.intf;
import java.io.Serializable;
diff --git a/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/Segmenter.java b/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/Segmenter.java
index fe3d5d3ec8e38724e051e432047a9daa2cceccf3..4176be720888105b3171572c91d4a50bc0b9297b 100644
--- a/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/Segmenter.java
+++ b/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/Segmenter.java
@@ -1,3 +1,10 @@
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
package org.apache.spark.ml.feature.df.intf;
import java.util.List;
diff --git a/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/TextDetail.java b/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/TextDetail.java
index 9f2a5e086fc47eef97a83f925804f7b715d9f60e..2300d82f522c22c56a642d9b4e143e5434b8d5f5 100644
--- a/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/TextDetail.java
+++ b/ml-kernel-client-core/src/main/scala/org/apache/spark/ml/feature/df/intf/TextDetail.java
@@ -1,3 +1,10 @@
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
package org.apache.spark.ml.feature.df.intf;
/**
diff --git a/ml-kernel-client-core/src/main/scala/org/apache/spark/mllib/clustering/GammaX.scala b/ml-kernel-client-core/src/main/scala/org/apache/spark/mllib/clustering/GammaX.scala
index adc79ea61236507f2b019f8ab47f6e1e324efb3c..dd0822f4221c8ddf104bf28dd30abf8766907bdd 100644
--- a/ml-kernel-client-core/src/main/scala/org/apache/spark/mllib/clustering/GammaX.scala
+++ b/ml-kernel-client-core/src/main/scala/org/apache/spark/mllib/clustering/GammaX.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client-core/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpanUtils.scala b/ml-kernel-client-core/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpanUtils.scala
index 0cccaa70ecdb1194bb4c94907166dd90676ed027..0c22abc9e3a5f1251fccc9f405d69eb45d8199e7 100644
--- a/ml-kernel-client-core/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpanUtils.scala
+++ b/ml-kernel-client-core/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpanUtils.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/pom.xml b/ml-kernel-client/pom.xml
index a56ccec57587b65900fb9bd5873f1edd090d4abe..02df475347a831f76eb2685038b2880260327769 100644
--- a/ml-kernel-client/pom.xml
+++ b/ml-kernel-client/pom.xml
@@ -2,12 +2,12 @@
org.apache.spark
boostkit-ml
- 2.2.0
+ 3.0.0
4.0.0
boostkit-ml-kernel-client_2.12
- 2.2.0
+ 3.0.0
${project.artifactId}
Spark ml core
@@ -17,6 +17,7 @@
boostkit-ml-core_2.12
${project.version}
${spark.version}
+ provided
@@ -35,6 +36,9 @@
+ -unchecked
+ -deprecation
+ -feature
-dependencyfile
${project.build.directory}/.scala_dependencies
diff --git a/ml-kernel-client/src/main/scala/breeze/linalg/DenseMatrixUtil.scala b/ml-kernel-client/src/main/scala/breeze/linalg/DenseMatrixUtil.scala
index b70b4009b7fc3bc59f85658e9f9185a7bccf5e59..cb9915c55f0d27c4ab5108480e98669767e87176 100644
--- a/ml-kernel-client/src/main/scala/breeze/linalg/DenseMatrixUtil.scala
+++ b/ml-kernel-client/src/main/scala/breeze/linalg/DenseMatrixUtil.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/breeze/linalg/DenseVectorUtil.scala b/ml-kernel-client/src/main/scala/breeze/linalg/DenseVectorUtil.scala
index 16b395972f452661e37efac941c5af8eb44dda52..d64f5757ff425d6aef63fe189c62838cf2d01306 100644
--- a/ml-kernel-client/src/main/scala/breeze/linalg/DenseVectorUtil.scala
+++ b/ml-kernel-client/src/main/scala/breeze/linalg/DenseVectorUtil.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/breeze/linalg/blas/Dgemv.scala b/ml-kernel-client/src/main/scala/breeze/linalg/blas/Dgemv.scala
index f06b7d1dccfd15a51322c985159a1932df8d804f..dba1adb18cd09e609029037726f347749379a44a 100644
--- a/ml-kernel-client/src/main/scala/breeze/linalg/blas/Dgemv.scala
+++ b/ml-kernel-client/src/main/scala/breeze/linalg/blas/Dgemv.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/breeze/linalg/blas/Gramian.scala b/ml-kernel-client/src/main/scala/breeze/linalg/blas/Gramian.scala
index 66883604d7f364125fc197d52c0be9b557c2e159..44777879f609001fc3e317839757a8549811cd13 100644
--- a/ml-kernel-client/src/main/scala/breeze/linalg/blas/Gramian.scala
+++ b/ml-kernel-client/src/main/scala/breeze/linalg/blas/Gramian.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/breeze/linalg/blas/YTYUtils.scala b/ml-kernel-client/src/main/scala/breeze/linalg/blas/YTYUtils.scala
index 8c43aa46b44c6e9a08c9edebbd51638b4f876e50..abba27e1f845c4bad481a9de7688ce394d7af510 100644
--- a/ml-kernel-client/src/main/scala/breeze/linalg/blas/YTYUtils.scala
+++ b/ml-kernel-client/src/main/scala/breeze/linalg/blas/YTYUtils.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/breeze/linalg/lapack/EigenDecomposition.scala b/ml-kernel-client/src/main/scala/breeze/linalg/lapack/EigenDecomposition.scala
index 46bab37651eb884b3708a196760577a87d135ddf..4340b029c9c078eba59ffdc530cca85679f0faf6 100644
--- a/ml-kernel-client/src/main/scala/breeze/linalg/lapack/EigenDecomposition.scala
+++ b/ml-kernel-client/src/main/scala/breeze/linalg/lapack/EigenDecomposition.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/breeze/optimize/LBFGSN.scala b/ml-kernel-client/src/main/scala/breeze/optimize/LBFGSN.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5a37d442a348276931e72c7226802daf76824cd9
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/breeze/optimize/LBFGSN.scala
@@ -0,0 +1,35 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package breeze.optimize
+
+import breeze.linalg.DenseVector
+import breeze.optimize.FirstOrderMinimizer.ConvergenceCheck
+import breeze.util.SerializableLogging
+
+class LBFGSN(convergenceCheck: ConvergenceCheck[DenseVector[Double]], m: Int)
+ extends LBFGS[DenseVector[Double]](convergenceCheck, m) with SerializableLogging {
+ def this(
+ maxIter: Int,
+ m: Int,
+ tolerance: Double,
+ absoluteConvergenceCheck: Boolean = true,
+ fValMemory: Int = 2) = {
+ this(LBFGSN.checkConvergence(maxIter, tolerance, absoluteConvergenceCheck, fValMemory), m)
+ }
+}
+
+object LBFGSN {
+ def checkConvergence(
+ maxIter: Int,
+ tol: Double,
+ absoluteConvergenceCheck: Boolean,
+ fValMemory: Int): ConvergenceCheck[DenseVector[Double]] = {
+ null
+ }
+}
diff --git a/ml-kernel-client/src/main/scala/breeze/optimize/OWLQNF.scala b/ml-kernel-client/src/main/scala/breeze/optimize/OWLQNF.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1affc3433231f2214f2dfe361ae16c3674178ee9
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/breeze/optimize/OWLQNF.scala
@@ -0,0 +1,25 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package breeze.optimize
+
+import breeze.linalg.{DenseVector => BDV}
+import breeze.optimize.FirstOrderMinimizer.ConvergenceCheck
+import breeze.util.SerializableLogging
+
+class OWLQNF(convergenceCheck: ConvergenceCheck[BDV[Double]], m: Int, l1reg: Int => Double)
+ extends OWLQN[Int, BDV[Double]](convergenceCheck, m, l1reg) with SerializableLogging {
+ def this(maxIter: Int, m: Int, l1reg: Int => Double, tolerance: Double) =
+ this(OWLQNF.checkConvergence(maxIter, tolerance), m, l1reg)
+}
+
+object OWLQNF {
+ def checkConvergence(maxIter: Int, tolerance: Double): ConvergenceCheck[BDV[Double]] = {
+ null
+ }
+}
diff --git a/ml-kernel-client/src/main/scala/com/huawei/bigdata/alogrithms/isolationforest/IsolationForest.scala b/ml-kernel-client/src/main/scala/com/huawei/bigdata/alogrithms/isolationforest/IsolationForest.scala
new file mode 100644
index 0000000000000000000000000000000000000000..f0ad6e492e87c3382b57329a2d7a75b290b64093
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/com/huawei/bigdata/alogrithms/isolationforest/IsolationForest.scala
@@ -0,0 +1,56 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.huawei.bigdata.alogrithms.isolationforest
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types.StructType
+
+
+trait IsolationForestParams extends Params {
+ def setNumTrees(value: Int): this.type = null
+ def setMaxInstances(value: Double): this.type = null
+ def setAnomalyRatio(value: Double): this.type = null
+ def setMaxFea(value: Double): this.type = null
+ def setBootstrap(value: Boolean): this.type = null
+ def setAnomalyRatioError(value: Double): this.type = null
+ def setRandomSeed(value: Long): this.type = null
+ def setFeaturesCol(value: String): this.type = null
+ def setPredictionCol(value: String): this.type = null
+ def setScoreCol(value: String): this.type = null
+}
+
+class Tree() extends Serializable {
+}
+
+class IsolationForestModel(
+ override val uid: String,
+ val trees: Array[Tree],
+ private val numSamples: Int)
+ extends Model[IsolationForestModel] with IsolationForestParams {
+ override def copy(extra: ParamMap): IsolationForestModel = null
+ override def transform(dataset: Dataset[_]): DataFrame = null
+ override def transformSchema(schema: StructType): StructType = null
+}
+
+object IsolationForestModel {
+ def saveModel(model: IsolationForestModel, path: String): Unit = {}
+ def loadModel(path: String): IsolationForestModel = null
+}
+
+class IsolationForest(override val uid: String) extends Estimator[IsolationForestModel]
+ with IsolationForestParams with DefaultParamsWritable with Logging {
+ def this() = this(Identifiable.randomUID("if"))
+ override def fit(dataset: Dataset[_]): IsolationForestModel = null
+ override def transformSchema(schema: StructType): StructType = null
+ override def copy(extra: ParamMap): IsolationForest = null
+}
diff --git a/ml-kernel-client/src/main/scala/com/intel/ssg/bdt/nlp/CRF.scala b/ml-kernel-client/src/main/scala/com/intel/ssg/bdt/nlp/CRF.scala
new file mode 100644
index 0000000000000000000000000000000000000000..94575e546c48d32f1e1797b21dcf9b27436e5a29
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/com/intel/ssg/bdt/nlp/CRF.scala
@@ -0,0 +1,156 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package com.intel.ssg.bdt.nlp
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+
+trait Regularization
+
+case object L1 extends Regularization
+
+case object L2 extends Regularization
+
+class CRF private (
+ private var freq: Int,
+ private var regParam: Double,
+ private var maxIterations: Int,
+ private var tolerance: Double,
+ private var regularization: Regularization) extends Serializable with Logging {
+
+ def this() = this(
+ freq = 1,
+ regParam = 0.5,
+ maxIterations = 1000,
+ tolerance = 1E-3,
+ regularization = L2)
+
+ def setRegParam(regParam: Double): this.type = null
+
+ def setFreq(freq: Int): this.type = null
+
+ def setMaxIterations(maxIterations: Int): this.type = null
+
+ def setTolerance(tol: Double): this.type = null
+
+ def setRegularization(reg: Regularization): this.type = null
+
+ def setCalcAccuracy(ca: Boolean): this.type = null
+
+ def setCompLevel(compLevel: Int): this.type = null
+
+ def setGlobalStageIterFraction(globalStageIterFraction: Double): this.type = null
+
+ def setCommFreeSplit(commFreeSplit: Int): this.type = null
+
+ def setCommFreeToleranceFactor(commFreeToleranceFactor: Double): this.type = null
+
+ def setNumThread(numThread: Int): this.type = null
+
+ def runCRF(
+ template: Array[String],
+ trains: RDD[Sequence],
+ testArrayWithLabel: Array[Sequence] = Array[Sequence](),
+ testArrayWithoutLabel: Array[Sequence] = Array[Sequence]()): CRFModel = null
+}
+
+
+case class Sequence (sequence: Array[Token]) extends Serializable {
+ var seqProb = 0.0
+ lazy val candidates = ArrayBuffer.empty[Sequence]
+
+ def setSeqProb(seqProb: Double): Sequence = {
+ this
+ }
+
+ def setCandidates(nBest: ArrayBuffer[Array[Int]],
+ probN: ArrayBuffer[Double],
+ labels: ArrayBuffer[String]): Sequence = {
+ this
+ }
+
+ def Print(): String = {
+ null
+ }
+
+ def nthPrint(k: Int): String = {
+ null
+ }
+
+ def nBestPrint(): String = {
+ null
+ }
+
+ override def toString: String = {
+ null
+ }
+
+ def toArray: Array[Token] = sequence
+
+ def compare(other: Sequence): Int = {
+ null.asInstanceOf[Int]
+ }
+
+ def probPrinter(): String = {
+ null
+ }
+
+}
+
+object Sequence {
+ def deSerializer(s: String): Sequence = {
+ null
+ }
+ def serializer(sequence: Sequence): String = {
+ null
+ }
+}
+
+class Token(
+ val label: String,
+ val tags: Array[String]) extends Serializable {
+ var prob : Array[(String, Double)] = null
+
+ def setProb(probMat: Array[(String, Double)]): Token = {
+ this
+ }
+
+ def probPrinter(): String = {
+ null
+ }
+
+ override def toString: String = {
+ null
+ }
+
+ def compare(other: Token): Int = {
+ null.asInstanceOf[Int]
+ }
+}
+
+object Token {
+
+ def deSerializer(s: String): Token = {
+ null
+ }
+
+ def serializer(token: Token): String = {
+ null
+ }
+
+ def put(label: String, tags: Array[String]): Token = {
+ null
+ }
+
+ def put(tags: Array[String]): Token = {
+ null
+ }
+}
diff --git a/ml-kernel-client/src/main/scala/com/intel/ssg/bdt/nlp/CRFModel.scala b/ml-kernel-client/src/main/scala/com/intel/ssg/bdt/nlp/CRFModel.scala
new file mode 100644
index 0000000000000000000000000000000000000000..eb39fe5ad9115107ba0eeefdd44c506c2d3a8d80
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/com/intel/ssg/bdt/nlp/CRFModel.scala
@@ -0,0 +1,94 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+package com.intel.ssg.bdt.nlp
+
+import java.io._
+
+import org.apache.spark.rdd.RDD
+
+trait VerboseMode
+
+case object VerboseLevel1 extends VerboseMode
+
+case object VerboseLevel2 extends VerboseMode
+
+case class CRFModel (
+ head: Array[String],
+ dic: Array[(String, Int)],
+ alpha: Array[Double]) extends Serializable {
+
+ protected def formatVersion = "1.0"
+
+ private var verboseMode: Option[VerboseMode] = None
+
+ private var nBest = 0
+ private var costFactor = 1.0
+
+ def setNBest(nBest: Int): CRFModel = {
+ this
+ }
+
+ def setVerboseMode(mode: VerboseMode): CRFModel = {
+ this
+ }
+
+ def setcostFact(cf: Double): CRFModel = {
+ this
+ }
+
+ override def toString: String = {
+ null
+ }
+
+ def toStringHead: String = {
+ null
+ }
+
+ def toArrayString: Array[String] = {
+ null
+ }
+
+ def predict(tests: RDD[Sequence]): RDD[Sequence] = {
+ null
+ }
+
+ def predict(tests: Array[Sequence]): Array[Sequence] = {
+ null
+ }
+
+ def testCRF(test: Sequence,
+ costFactor: Double,
+ vMode: Option[VerboseMode]): Sequence = {
+ null
+ }
+}
+
+object CRFModel {
+ def load(source: String): CRFModel = {
+ null
+ }
+
+ def loadBinaryFile(path: String): CRFModel = {
+ null
+ }
+
+ def loadArray(source: Array[String]): CRFModel = {
+ null
+ }
+
+ def save(model: CRFModel): String = {
+ null
+ }
+
+ def saveBinaryFile(model: CRFModel, path: String): Unit = {
+ }
+
+ def saveArray(model: CRFModel): Array[String] = {
+ null
+ }
+}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/StaticUtils.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/StaticUtils.scala
index 971fccfd92407fa7ca882bcfe5cc65fe6201be4f..317cf449089ab4681dd9b70da5c5e349308d004a 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/StaticUtils.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/StaticUtils.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
@@ -17,5 +11,6 @@ package org.apache.spark.ml
object StaticUtils {
val ZERO_LONG: Long = 0L
val ZERO_INT: Int = 0
+ val ONE_INT: Int = 1
val ZERO_DOUBLE: Double = 0.0
}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/clustering/DBSCAN.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/clustering/DBSCAN.scala
index 4c813b8e20ad387d0c60f63216d124fb1438f6e4..8492eea8a4c68aa8b3b0ec7ab007c4d43f12b353 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/clustering/DBSCAN.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/clustering/DBSCAN.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala
index c152970a4c862b924b58b3a05512790f6c71656d..961f24963bb3c8c46588777f8c2b13ad90f62fce 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/DecisionTreeBucketizer.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2022. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/EncoderUtils.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/EncoderUtils.scala
new file mode 100644
index 0000000000000000000000000000000000000000..99075ba1a6669d2c39f4100e0baed8b89f97beff
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/EncoderUtils.scala
@@ -0,0 +1,15 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+package org.apache.spark.ml.feature
+
+object EncoderUtils {
+
+ val DEFAULT_THREAD_NUM = 40
+
+ def save2PathPar(hdfsPath: String, localPath: String, numT: Int = DEFAULT_THREAD_NUM): Unit = {}
+}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/recommendation/SimRank.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/FeatureEncoding.scala
similarity index 30%
rename from ml-kernel-client/src/main/scala/org/apache/spark/ml/recommendation/SimRank.scala
rename to ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/FeatureEncoding.scala
index b011327ae2aa1bcdcafbb31fe3aa8b117f16a03d..bba1aeb0e246492f598fc014befdf5236550c56d 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/recommendation/SimRank.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/FeatureEncoding.scala
@@ -6,28 +6,22 @@
* http://www.apache.org/licenses/LICENSE-2.0
*/
-package org.apache.spark.ml.recommendation
+package org.apache.spark.ml.feature
-import org.apache.spark.ml.param.{ParamMap, Params}
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.DataFrame
-trait SimRankParams extends Params {}
+class FeatureEncoding extends Serializable{
+ def execute(dataset: DataFrame = null): Unit = {}
-case class SimRankSimilarity(userSimilarity: DataFrame, itemSimilarity: DataFrame)
+ def setMapLoadPath(mapLoadPath: String): this.type = null
-class SimRank(override val uid: String) extends SimRankParams {
- def this() = this(Identifiable.randomUID("SimRank"))
+ def setDataPath(dataPath: String): this.type = null
- def setDamp(value: Double): this.type = null
+ def setOutputFilePath(outputFilePath: String): this.type = null
- def setNumIter(value: Int): this.type = null
+ def setLocalSavePath(localSavePath: String): this.type = null
- def setUserCol(value: String): this.type = null
+ def setEncodeColumns(encodeColumns: String): this.type = null
- def setItemCol(value: String): this.type = null
-
- def computeSimilarity(dataset: Dataset[_]): SimRankSimilarity = null
-
- override def copy(extra: ParamMap): Params = null
+ def setNumThread(numThread: Int): this.type = null
}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
new file mode 100644
index 0000000000000000000000000000000000000000..91d27c51524408f146d3b2a552757a2c8f7b11f7
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
@@ -0,0 +1,240 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+package org.apache.spark.ml.feature
+
+import org.apache.spark.ml.{Estimator, Model, PipelineStage}
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable,
+ MLReadable, MLReader, MLWritable, MLWriter}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types.StructType
+
+trait TargetEncoderParams extends Params {
+ /**
+ * Get the name of the fold column for out-of-fold encoding.
+ * @return name of the fold column
+ */
+ def getFoldCol: String = null
+
+ /**
+ * Get the name of the label column.
+ * @return name of the label column
+ */
+ def getLabelCol: String = null
+
+ /**
+ * Get names of the columns to encode.
+ * @return names of the columns to encode
+ */
+ def getInputCols: Array[String] = null
+
+ /**
+ * Get number of folds for out-of-fold encoding.
+ * @return number of unique fold ids in the fold column
+ */
+ def getNumFolds: Int = 0
+
+ /**
+ * Get the seed to generate fold column if the fold column is not provided.
+ * @return seed to generate fold column
+ */
+ def getFoldSeed: Int = 0
+
+ /**
+ * Get names of the output columns. If not set, they will be the input column names appended with
+ * a "_te" suffix.
+ * @return names of the output columns
+ */
+ def getOutputCols: Array[String] = null
+
+ /**
+ * Get the BlendedAvgInflectionPoint parameter.
+ * @return the BlendedAvgInflectionPoint parameter
+ */
+ def getBlendedAvgInflectionPoint: Double = 0
+
+ /**
+ * Get the BlendedAvgSmoothing parameter.
+ * @return the BlendedAvgSmoothing parameter
+ */
+ def getBlendedAvgSmoothing: Double = 0
+}
+
+trait TargetEncoderBase extends PipelineStage with TargetEncoderParams {
+ override def transformSchema(schema: StructType): StructType = null
+}
+
+/**
+ * TargetEncoder is a feature engineering method for categorical features. It replaces each category
+ * with its target mean, transforming columns of categorical values into columns of Double vectors.
+ * It supports both classification and regression problems. For classification problems, the length
+ * of result vectors equals to number of classes - 1. For regression, the length of result vectors
+ * equals to 1. The implementation handles target leakage by out-of-fold encoding and blended
+ * averaging.
+ * @param uid uid of TargetEncoder
+ */
+class TargetEncoder(override val uid: String)
+ extends Estimator[TargetEncoderModel]
+ with TargetEncoderParams with TargetEncoderBase
+ with DefaultParamsWritable {
+
+ def this() = this(Identifiable.randomUID("TargetEncoder"))
+
+ /**
+ * Fit the given dataset with TargetEncoder.
+ * @param dataset dataset to be fitted
+ * @return TargetEncoderModel
+ */
+ override def fit(dataset: Dataset[_]): TargetEncoderModel = null
+
+ /**
+ * Copy a TargetEncoder instance.
+ * @param extra extra ParamMap
+ * @return TargetEncoder
+ */
+ override def copy(extra: ParamMap): TargetEncoder = null
+
+ /**
+ * Set the name of the label column. The label column should be Double type and contains no
+ * missing value. For classification problem, the value of label column should be in range
+ * [0, numClasses - 1].
+ * @param value name of the label column
+ * @return TargetEncoder
+ */
+ def setLabelCol(value: String): this.type = null
+
+ /**
+ * Set names of the columns to encode. Input columns should be String type. Missing values are
+ * treated as a new category, handled automatically by TargetEncoder.
+ * @param value names of the columns to encode
+ * @return TargetEncoder
+ */
+ def setInputCols(value: Array[String]): this.type = null
+
+ /**
+ * Set names of the output columns. The order of output columns should match input columns. If
+ * not set, they will be input column names appended with a "_te" suffix.
+ * @param value names of the output columns
+ * @return TargetEncoder
+ */
+ def setOutputCols(value: Array[String]): this.type = null
+
+ /**
+ * Set number of folds for out-of-fold encoding. Larger value means stronger regularization. If
+ * fold column is provided in the training DataFrame, numFolds should match the number of unique
+ * fold ids in it. When not provided, a fold column will be automatically created according to
+ * this value. NumFolds should be in range [2, 10], the default value is 4.
+ * @param value number of unique fold ids in the fold column
+ * @return TargetEncoder
+ */
+ def setNumFolds(value: Int): this.type = null
+
+ /**
+ * Set name of the fold column for out-of-fold encoding. The fold column is an Int column with
+ * range [0, numFolds - 1], a category will be encoded with all targets of the same category from
+ * different fold ids. If the column is not provided in training DataFrame, it will be
+ * automatically created. The default value is "fold".
+ * @param value number of unique fold ids in the fold column
+ * @return TargetEncoder
+ */
+ def setFoldCol(value: String): this.type = null
+
+ /**
+ * Set the seed for generating the fold column when it is not provided in the training DataFrame.
+ * @param value seed for generating the fold column
+ * @return TargetEncoder
+ */
+ def setFoldSeed(value: Int): this.type = null
+
+ /**
+ * Set the parameter of the blended average. The bigger number is set, the groups relatively
+ * bigger to the overall data set size will consider the global target value as a component in the
+ * weighted average. The default value is 10.
+ * @param value parameter of the blended average
+ * @return TargetEncoder
+ */
+ def setBlendedAvgInflectionPoint(value: Double): this.type = null
+
+ /**
+ * Set the parameter of blended average. Controls the rate of transition between a group target
+ * value and a global target value. The default value is 20. It should be positive.
+ * @param value parameter of the blended average
+ * @return TargetEncoder
+ */
+ def setBlendedAvgSmoothing(value: Double): this.type = null
+
+ /**
+ * Set the problem type. Supported problem types include "classification" and "regression".
+ * @param value problem type
+ * @return TargetEncoder
+ */
+ def setProblemType(value: String): this.type = null
+}
+
+/**
+ * TargetEncoderModel contains the aggregation information of the training DataFrame.
+ * @param uid uid of TargetEncoderModel
+ * @param ooFoldAggs out-of-fold aggregation
+ * @param globalAggs global aggregation
+ * @param prior target mean
+ */
+class TargetEncoderModel(
+ override val uid: String,
+ val ooFoldAggs: Array[RDD[(Long, (Array[Double], Int))]],
+ val globalAggs: Array[RDD[(Long, (Array[Double], Int))]],
+ val prior: Array[Double])
+ extends Model[TargetEncoderModel] with MLWritable with DefaultParamsWritable
+ with TargetEncoderBase with TargetEncoderParams {
+ /**
+ * Transform the dataset with target encoding. Out-of-fold encoding is only relevant for training
+ * datasets. Thus the transform method has to treat training dataset and test dataset differently.
+ * When TargetEncoder is used inside a ML pipeline, the differentiation is done automatically. But
+ * if a user decides to use TargetEncoder without ML pipeline, use "transformTrainingDataset" and
+ * "transformTestDataset" instead.
+ * @param dataset dataset to be transformed
+ * @return transformed dataset
+ */
+ override def transform(dataset: Dataset[_]): DataFrame = null
+
+ /**
+ * Transform the training dataset with out-of-fold encoding to prevent target leakage. Unseen
+ * categories will be encoded with the target mean of the whole training dataset.
+ * @param dataset training dataset
+ * @return transformed training dataset
+ */
+ def transformTrainingDataset(dataset: Dataset[_]): DataFrame = null
+
+ /**
+ * Encode the test dataset with global category mean.
+ * @param dataset test dataset
+ * @return transformed test dataset
+ */
+ def transformTestDataset(dataset: Dataset[_]): DataFrame = null
+
+ /**
+ * Copy a TargetEncoderModel instance.
+ * @param extra extra ParamMap
+ * @return TargetEncoderModel
+ */
+ override def copy(extra: ParamMap): TargetEncoderModel = null
+
+ /**
+ * Write method.
+ * @return MLWriter
+ */
+ override def write(): MLWriter = null
+}
+
+object TargetEncoderModel extends MLReadable[TargetEncoderModel] {
+ /**
+ * Read method.
+ * @return MLReader[TargetEncoderModel]
+ */
+ override def read: MLReader[TargetEncoderModel] = null
+}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/neighbors/KNN.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/neighbors/KNN.scala
index dc7b364660f24b258043a714afe5d2968898c23f..150cabcfe953edb31d6bc3c040ae70d9f4a75681 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/neighbors/KNN.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/neighbors/KNN.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/neighbors/KNNUtils.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/neighbors/KNNUtils.scala
index a987b78d7b07866126a7be35d2ee152fd80305c5..889bfe34655ca93c3c0a6bbfb73240d21770392f 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/neighbors/KNNUtils.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/neighbors/KNNUtils.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/recommendation/ALSUtils.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/recommendation/ALSUtils.scala
index bf33ee639589bb0b9a91088a0c417f88e2aacc67..54afe39c9502761905dad36884a241a5a6673424 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/recommendation/ALSUtils.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/recommendation/ALSUtils.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/recommendation/NMF.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/recommendation/NMF.scala
new file mode 100644
index 0000000000000000000000000000000000000000..916866e34eb0559d3bad4ab8ebb04f19494dfb90
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/recommendation/NMF.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+//
+package org.apache.spark.ml.recommendation
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.ml.param.shared.{HasBlockSize, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, MLWritable, MLWriter}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types.StructType
+
+
+/**
+ * Common params for NMF and NMFModel.
+ */
+private[recommendation] trait NMFModelParams extends Params with HasPredictionCol
+ with HasBlockSize {
+
+}
+
+/**
+ * Common params for NMF.
+ */
+private[recommendation] trait NMFParams extends NMFModelParams with HasMaxIter with HasRegParam
+ with HasCheckpointInterval with HasSeed {
+
+}
+
+class NMFModel private[ml](
+ override val uid: String,
+ val rank: Int,
+ @transient val userFactors: DataFrame,
+ @transient val itemFactors: DataFrame)
+ extends Model[NMFModel] with NMFModelParams with MLWritable {
+ /** @group setParam */
+ def setUserCol(value: String): this.type = null
+
+ /** @group setParam */
+ def setItemCol(value: String): this.type = null
+
+ /** @group setParam */
+ def setPredictionCol(value: String): this.type = null
+
+ /** @group expertSetParam */
+ def setColdStartStrategy(value: String): this.type = null
+
+ /**
+ * Set block size for stacking input data in matrices.
+ * Default is 4096.
+ *
+ * @group expertSetParam
+ */
+ def setBlockSize(value: Int): this.type = null
+
+ override def transform(dataset: Dataset[_]): DataFrame = null
+
+ override def transformSchema(schema: StructType): StructType = null
+
+ override def copy(extra: ParamMap): NMFModel = null
+
+ override def write: MLWriter = null
+
+ override def toString: String = null
+
+ /**
+ * Returns top `numItems` items recommended for each user, for all users.
+ *
+ * @param numItems max number of recommendations for each user
+ * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
+ * stored as an array of (itemCol: Int, rating: Float) Rows.
+ */
+ def recommendForAllUsers(numItems: Int): DataFrame = null
+
+ /**
+ * Returns top `numItems` items recommended for each user id in the input data set. Note that if
+ * there are duplicate ids in the input dataset, only one set of recommendations per unique id
+ * will be returned.
+ *
+ * @param dataset a Dataset containing a column of user ids.
+ * The column name must match `userCol`.
+ * @param numItems max number of recommendations for each user.
+ * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
+ * stored as an array of (itemCol: Int, rating: Float) Rows.
+ */
+ def recommendForUserSubset(dataset: Dataset[_], numItems: Int): DataFrame = null
+
+ /**
+ * Returns top `numUsers` users recommended for each item, for all items.
+ *
+ * @param numUsers max number of recommendations for each item
+ * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
+ * stored as an array of (userCol: Int, rating: Float) Rows.
+ */
+ def recommendForAllItems(numUsers: Int): DataFrame = null
+
+ /**
+ * Returns top `numUsers` users recommended for each item id in the input data set. Note that if
+ * there are duplicate ids in the input dataset, only one set of recommendations per unique id
+ * will be returned.
+ *
+ * @param dataset a Dataset containing a column of item ids.
+ * The column name must match `itemCol`.
+ * @param numUsers max number of recommendations for each item.
+ * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
+ * stored as an array of (userCol: Int, rating: Float) Rows.
+ */
+ def recommendForItemSubset(dataset: Dataset[_], numUsers: Int): DataFrame = null
+}
+
+class NMF(override val uid: String) extends Estimator[NMFModel] with NMFParams
+ with DefaultParamsWritable {
+
+ def this() = this(Identifiable.randomUID("nmf"))
+
+ /** @group setParam */
+ def setRank(value: Int): this.type = null
+
+ /** @group setParam */
+ def setNumUserBlocks(value: Int): this.type = null
+
+ /** @group setParam */
+ def setNumItemBlocks(value: Int): this.type = null
+
+ /** @group setParam */
+ def setUserCol(value: String): this.type = null
+
+ /** @group setParam */
+ def setItemCol(value: String): this.type = null
+
+ /** @group setParam */
+ def setRatingCol(value: String): this.type = null
+
+ /** @group setParam */
+ def setPredictionCol(value: String): this.type = null
+
+ /** @group setParam */
+ def setMaxIter(value: Int): this.type = null
+
+ /** @group setParam */
+ def setRegParam(value: Double): this.type = null
+
+ /** @group setParam */
+ def setCheckpointInterval(value: Int): this.type = null
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = null
+
+ /** @group expertSetParam */
+ def setIntermediateStorageLevel(value: String): this.type = null
+
+ /** @group expertSetParam */
+ def setFinalStorageLevel(value: String): this.type = null
+
+ /** @group expertSetParam */
+ def setColdStartStrategy(value: String): this.type = null
+
+ /**
+ * Set block size for stacking input data in matrices.
+ * Default is 4096.
+ *
+ * @group expertSetParam
+ */
+ def setBlockSize(value: Int): this.type = null
+
+ /**
+ * Sets both numUserBlocks and numItemBlocks to the specific value.
+ *
+ * @group setParam
+ */
+ def setNumBlocks(value: Int): this.type = {
+ setNumUserBlocks(value)
+ setNumItemBlocks(value)
+ this
+ }
+
+ override def fit(dataset: Dataset[_]): NMFModel = null
+
+ override def transformSchema(schema: StructType): StructType = null
+
+ override def copy(extra: ParamMap): NMF = null
+}
+
+object NMF extends DefaultParamsReadable[NMF] with Logging {
+ /**
+ * Rating class for better code readability.
+ */
+ case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
+}
+
+object NMFSolver {
+ private[recommendation] def updateFactors[ID](
+ dstIds: Array[ID],
+ srcPtrs: Array[Int],
+ srcEncodedIndices: Array[Int],
+ ratings: Array[Float],
+ srcFactors: Iterable[(Int, Array[Array[Float]])],
+ numSrcBlocks: Int,
+ rank: Int,
+ regParam: Float,
+ srcEncoder: NMFLocalIndexEncoder): Array[Array[Float]] = null
+}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/DTUtils.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/DTUtils.scala
index 74d74b9eab1dc593d676f0d84a1e110c35efe4d1..c2b94846eaaec301f983dbc82b99395c3ee2e218 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/DTUtils.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/DTUtils.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesUtil.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesUtil.scala
index 1422e2c3eda453fbea099b0b84584a21501bcca1..2faaaac6197cbc3315315316a69239f87f86480a 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesUtil.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesUtil.scala
@@ -1,10 +1,4 @@
// scalastyle:off
-/*
-* Copyright (C) 2022. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala
index d9f186c3af6e7a2cefb43f34d2fb06c4954ffdfc..844ec804460b6d3cfe387092c504f32918e807a5 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tree/impl/RFUtils.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2022. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidator.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..2024e5d4ae84a1aef810a713a415a39eed2e4673
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidator.scala
@@ -0,0 +1,66 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasParallelism
+import org.apache.spark.ml.util.{Identifiable, MLWritable, MLWriter}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types.StructType
+
+class BayesianCrossValidator(override val uid: String)
+ extends Estimator[BayesianCrossValidatorModel]
+ with BayesianCrossValidatorParams with HasParallelism {
+
+ def this() = this(Identifiable.randomUID("bayesCv"))
+
+ override def transformSchema(schema: StructType): StructType = null
+
+ def setEstimator(value: Estimator[_]): this.type = null
+
+ def setEvaluator(value: Evaluator): this.type = null
+
+ def setNumFolds(value: Int): this.type = null
+
+ def setNumIterations(value: Int): this.type = null
+
+ def setSeed(value: Long): this.type = null
+
+ def setEstimatorParamSpace(value: ParamSpace): this.type = null
+
+ def setParallelism(value: Int): this.type = null
+
+ def setThreshold(value: Double): this.type = null
+
+ def setThresholdFlag(value: Boolean): this.type = null
+
+ def getSearchNumber: Int = 0
+
+ def getBestMetric: Double = 0.0
+
+ override def fit(dataset: Dataset[_]): BayesianCrossValidatorModel = null
+
+ override def copy(extra: ParamMap): BayesianCrossValidator = null
+}
+
+class BayesianCrossValidatorModel private[ml](
+ override val uid: String,
+ val bestModel: Model[_])
+ extends Model[BayesianCrossValidatorModel] with BayesianCrossValidatorParams with MLWritable {
+
+ override def write: MLWriter = null
+
+ override def copy(extra: ParamMap): BayesianCrossValidatorModel = null
+
+ override def transform(dataset: Dataset[_]): DataFrame = null
+
+ override def transformSchema(schema: StructType): StructType = null
+}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/SPCA.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidatorParams.scala
similarity index 33%
rename from ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/SPCA.scala
rename to ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidatorParams.scala
index 6a079b1ab006151006930332e07d0b4150877736..df0ea3ef358689eedc58921dd30df25494aefff5 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/ml/feature/SPCA.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/BayesianCrossValidatorParams.scala
@@ -5,30 +5,30 @@
* the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*/
-package org.apache.spark.ml.feature
+
+package org.apache.spark.ml.tuning
import org.apache.spark.ml.Estimator
-import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
-import org.apache.spark.sql.Dataset
+import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.param.Params
+import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.sql.types.StructType
-class SPCA(override val uid: String)
- extends Estimator[PCAModel] with PCAParams with DefaultParamsWritable {
+private[ml] trait BayesianCrossValidatorParams extends HasSeed with Params {
- def this() = this(Identifiable.randomUID("spca"))
+ def getEstimator: Estimator[_] = null
- def setInputCol(value: String): this.type = null
+ def getEvaluator: Evaluator = null
- def setOutputCol(value: String): this.type = null
+ def getNumFolds: Int = 0
- def setK(value: Int): this.type = null
+ def getNumIterations: Int = 0
- def setMode(value: String): this.type = null
+ def getEstimatorParamSpace: ParamSpace = null
- override def fit(dataset: Dataset[_]): PCAModel = null
+ def getThreshold: Double = 0.0
- override def transformSchema(schema: StructType): StructType = null
+ def getThresholdFlag: Boolean = true
- override def copy(extra: ParamMap): SPCA = null
+ protected def transformSchemaImpl(schema: StructType): StructType = null
}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/EI.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/EI.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b4817d25d5900cfca47cbf8b9198d73b75490291
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/EI.scala
@@ -0,0 +1,13 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.spark.ml.tuning
+
+private[ml] class EI() {
+ def compute(preds: Array[(Double, Double)], curBest: Double, par: Double): Array[Double] = null
+}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/RfSurrogate.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/RfSurrogate.scala
new file mode 100644
index 0000000000000000000000000000000000000000..2c07a0f54c7679cf8e322f309d6ac618a7ce8e2a
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/RfSurrogate.scala
@@ -0,0 +1,24 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.sql.SparkSession
+
+class RfSurrogate(val ss: SparkSession, val minimize: Boolean = false) {
+ def curBest: Double = 0.0
+
+ def update(x: Array[Vector], y: Array[Double]): Unit = {
+ }
+
+ def train(): Unit = {
+ }
+
+ def predict(x: Array[Vector]): Array[(Double, Double)] = null
+}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/Solver.scala b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/Solver.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e34bd9b9bb13f5e63cf2b519e05be6ec45f65c9c
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/ml/tuning/Solver.scala
@@ -0,0 +1,28 @@
+// scalastyle:off header.matches
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.SparkSession
+
+class Solver(
+ val ss: SparkSession,
+ val paramSpace: ParamSpace,
+ val minimize: Boolean,
+ val batchSize: Int,
+ val sampleSize: Int) {
+ def getHistory: (Array[Vector], Array[Double]) = null
+
+ def suggest(): Array[ParamMap] = null
+
+ def feed(configs: Array[ParamMap], y: Array[Double]): Unit = { }
+
+ def feed(config: ParamMap, y: Double): Unit = { }
+}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/mllib.clustering/LDAUtilsXOpt.scala b/ml-kernel-client/src/main/scala/org/apache/spark/mllib.clustering/LDAUtilsXOpt.scala
index e06a3d6619503dc11408b892c73f3aa5d27ea388..483569cebe1efb178d8c3462f8eec6b8f7a105e0 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/mllib.clustering/LDAUtilsXOpt.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/mllib.clustering/LDAUtilsXOpt.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/mllib.linalg.distributed/RowMatrixUtil.scala b/ml-kernel-client/src/main/scala/org/apache/spark/mllib.linalg.distributed/RowMatrixUtil.scala
index 60f51e63877d76c040c6cedb83f8277711892e0e..b65a5631e693d2bc368ebdfe2dae5dda9e1ffc2f 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/mllib.linalg.distributed/RowMatrixUtil.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/mllib.linalg.distributed/RowMatrixUtil.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/SPCA.scala b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/SPCA.scala
deleted file mode 100644
index f896112df171134d2785519dc6ddcba5a5820a91..0000000000000000000000000000000000000000
--- a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/SPCA.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-// scalastyle:off header.matches
-/*
- * This file to You under the Apache License, Version 2.0;
- * you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- * http://www.apache.org/licenses/LICENSE-2.0
- */
-package org.apache.spark.mllib.feature
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.ml.StaticUtils
-import org.apache.spark.mllib.linalg.{SparseVector, Vector}
-import org.apache.spark.rdd.RDD
-
-class SPCA(val k: Int) extends Logging {
-
- def setMode(mode: String): this.type = null
-
- def setHighDimensionThreshold(highDimensionThreshold: Int): this.type = null
-
- def setNThread(nThread: Int): this.type = null
-
- def fit(sources: RDD[Vector]): PCAModel = null
-
- def evaluate(
- rows: RDD[SparseVector],
- nRows: Long,
- nCols: Int,
- nThread: Int,
- explainedVariance: Array[Double],
- highDimensionThreshold: Int = 10000000): Double = 0.0
-}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/Word2VecSGHS.scala b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/Word2VecSGHS.scala
index e69cc053014419e968afd03a6680ee6f2342ffa2..4bc666032dcf27640ca1f45f325f839cb5986c4d 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/Word2VecSGHS.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/feature/Word2VecSGHS.scala
@@ -1,10 +1,10 @@
-// scalastyle:off
+// scalastyle:off header.matches
/*
-* Copyright (C) 2022. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
package org.apache.spark.mllib.feature
import scala.collection.mutable
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelationUtil.scala b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthUtil.scala
similarity index 33%
rename from ml-kernel-client/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelationUtil.scala
rename to ml-kernel-client/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthUtil.scala
index 00314ea75a364c8dce2711de0a13a40b18d97ad3..9ea51450929e766c9429197c0c66fbaee6e72fac 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelationUtil.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/fpm/FPGrowthUtil.scala
@@ -1,32 +1,37 @@
-// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
+// scalastyle:off
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*/
+package org.apache.spark.mllib.fpm
-package org.apache.spark.mllib.stat.correlation
-
+import org.apache.spark.Partitioner
import org.apache.spark.internal.Logging
-import org.apache.spark.mllib.linalg.{DenseVector, Matrix}
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
import org.apache.spark.rdd.RDD
-object PearsonCorrelationUtil extends Logging {
+import scala.reflect.ClassTag
- def computeDenseVectorCorrelation(rows: RDD[DenseVector]): Matrix = {
+object FPGrowthUtils extends Logging {
+
+ def genFreqItemsetsByOptLevel1[Item: ClassTag](
+ data: RDD[Array[Item]],
+ minCount: Long,
+ freqItems: Array[Item],
+ partitioner: Partitioner,
+ timeLimit1: Double): RDD[FreqItemset[Item]] = {
null
}
- def computeCorrelationMatrixFromUpperCovariance
- (upperCov: Array[Double], n: Int): Matrix = {
+ def genFreqItemsetsByOptLevel2[Item: ClassTag](
+ data: RDD[Array[Item]],
+ minCount: Long,
+ freqItems: Array[Item],
+ partitioner: Partitioner,
+ timeLimit1: Double,
+ timeLimit2: Double): RDD[FreqItemset[Item]] = {
null
}
-
}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpanUtils.scala b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpanUtils.scala
index 9e971a6161b7d9d1f42a31cf5f68ce6f6cb04b10..eefa71feb63ad94048e75cd67a5456a425e1ef28 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpanUtils.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpanUtils.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelationUtil.scala b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/optimization/CostFunOpt.scala
similarity index 44%
rename from ml-kernel-client/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelationUtil.scala
rename to ml-kernel-client/src/main/scala/org/apache/spark/mllib/optimization/CostFunOpt.scala
index 9ea78b1670f389aae8426ff0bf45e31bea525d2e..ac7ea57da888e63fa0afc8edf2b0e28f237da866 100644
--- a/ml-kernel-client/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelationUtil.scala
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/mllib/optimization/CostFunOpt.scala
@@ -1,10 +1,4 @@
// scalastyle:off header.matches
-/*
-* Copyright (C) 2021. Huawei Technologies Co., Ltd.
-* This program is distributed in the hope that it will be useful,
-* but WITHOUT ANY WARRANTY; without even the implied warranty of
-* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
-* */
/*
* This file to You under the Apache License, Version 2.0;
* you may not use this file except in compliance with
@@ -12,18 +6,18 @@
* http://www.apache.org/licenses/LICENSE-2.0
*/
-package org.apache.spark.mllib.stat.correlation
+package org.apache.spark.mllib.optimization
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
-object SpearmanCorrelationUtil {
-
- /**
- * Get a matrix of ranks for the matrix given
- */
- def getRanks(X: RDD[Vector]): RDD[Vector] = {
+object CostFunOpt {
+ def aggregate(
+ data: RDD[(Double, Vector)],
+ dim: Int,
+ gradient: Gradient,
+ bcW: Broadcast[Vector]): (Vector, Double) = {
null
}
-
}
diff --git a/ml-kernel-client/src/main/scala/org/apache/spark/nlp/CRFUtil.scala b/ml-kernel-client/src/main/scala/org/apache/spark/nlp/CRFUtil.scala
new file mode 100644
index 0000000000000000000000000000000000000000..2514f8fbb2c6ed341f7f950eb8e8a60977e7c43e
--- /dev/null
+++ b/ml-kernel-client/src/main/scala/org/apache/spark/nlp/CRFUtil.scala
@@ -0,0 +1,43 @@
+// scalastyle:off
+/*
+ * This file to You under the Apache License, Version 2.0;
+ * you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+package org.apache.spark.nlp
+
+import breeze.optimize.DiffFunction
+import com.intel.ssg.bdt.nlp.Tagger
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.optimization.Updater
+import org.apache.spark.rdd.RDD
+import breeze.linalg.{DenseVector => BDV}
+import org.apache.spark.SparkContext
+
+import scala.collection.mutable.ArrayBuffer
+
+class CRFGradientX extends Serializable {
+}
+
+object CRFGradientX {
+ def dataProcess(data: RDD[Tagger], nThread: Int): RDD[Array[ArrayBuffer[Tagger]]] = {
+ null
+ }
+}
+
+class CostFunX(
+ Data: RDD[Array[ArrayBuffer[Tagger]]],
+ gradient: CRFGradientX,
+ updater: Updater,
+ regParam: Double,
+ compLevel: Int,
+ numThread: Int) extends DiffFunction[BDV[Double]] with Logging with Serializable {
+
+ override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+ null
+ }
+
+ def setDriverCoreFromSparkConf(sc: SparkContext): Unit = {
+ }
+}
diff --git a/pom.xml b/pom.xml
index 1acb85ecb90e07564f66fc9a7ea36fefe1fcc528..3d546f2771a7139f12bb9b34f1091d6c4e484acc 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2,7 +2,7 @@
4.0.0
org.apache.spark
boostkit-ml
- 2.2.0
+ 3.0.0
${project.artifactId}
Spark ml algo
2020
@@ -12,7 +12,7 @@
1.8
1.8
UTF-8
- spark3.1.1
+ spark3.3.1
ml-kernel-client-core
@@ -24,17 +24,20 @@
org.apache.spark
spark-mllib_2.12
- 3.1.1
+ 3.3.1
+ provided
it.unimi.dsi
fastutil
8.3.1
+ provided
org.scalatest
scalatest_2.12
3.2.0
+ provided
@@ -62,7 +65,7 @@
-
+