集群提交lightGBM算法

in 互联网前沿
关注公众号【好便宜】( ID:haopianyi222 ),领红包啦~
阿里云,国内最大的云服务商,注册就送数千元优惠券:https://t.cn/AiQe5A0g
腾讯云,良心云,价格优惠: https://t.cn/AieHwwKl
搬瓦工,CN2 GIA 优质线路,搭梯子、海外建站推荐: https://t.cn/AieHwfX9
[root@hadoop-1-1 ~]# more lgbm.sh
# Submit the LightGBM demo job to YARN; --jars ships the mmlspark and lightgbmlib jars with the job.
# --class must be the fully qualified name of the entry point: package com.xx.demo, object lgmClassifier.
/app/spark2.3/bin/spark-submit \
--master yarn \
--jars /root/external_pkgs/mmlspark-0.15.jar,/root/external_pkgs/lightgbmlib-2.2.200.jar \
--class com.xx.demo.lgmClassifier /root/lgbm_demo.jar
nohup sh lgbm.sh > lgbm_20191226_001.log 2>&1 &
package com.xx.demo

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.StandardScaler
import com.microsoft.ml.spark.LightGBMClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}

/**
 * Spark driver that cross-validates a LightGBM binary classifier on the
 * Pima Indian diabetes CSV and prints the AUC measured on a held-out split.
 *
 * Pipeline: VectorAssembler -> StandardScaler -> LightGBMClassifier, tuned over
 * `learningRate` with a CrossValidator and scored with areaUnderROC.
 */
object lgmClassifier {

  /**
   * Entry point invoked by spark-submit.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
      .setAppName("lgbm_app")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      // NOTE(review): this is a YARN NodeManager property; setting it on SparkConf
      // presumably has no effect (it normally lives in yarn-site.xml) -- confirm and remove.
      .set("yarn.nodemanager.vmem-check-enabled", "false")
    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // Always release the session, even when loading, training or evaluation throws.
    try {
      val labelCol = "Outcome"
      val inputPath = "/user/spark/H2O/data/PimaIndian.csv"

      // The file is read without a header row; column names are assigned below.
      val data = sparkSession.read
        .format("csv")
        .option("sep", ",")
        .option("inferSchema", "true")
        .option("header", "false")
        .load(inputPath)

      val schemas = Seq("Pregnancles", "Glucose", "BloodPressure", "SkinThickness", "Insulin",
        "BMI", "DiabetesPedigreeFuction", "Age", labelCol)
      val dataset = data.toDF(schemas: _*)

      // Every column except the label is a feature (exact match, not substring match).
      val featureCols = dataset.columns.filterNot(_ == labelCol)
      val vectorAssembler = new VectorAssembler()
        .setInputCols(featureCols)
        .setOutputCol("features")

      val scaler = new StandardScaler()
        .setInputCol("features")
        .setOutputCol("scaledFeatures")
        .setWithStd(true)
        .setWithMean(false)

      val lgbm = new LightGBMClassifier()
        .setLabelCol(labelCol)
        .setFeaturesCol("scaledFeatures")

      val pipeline = new Pipeline().setStages(Array(vectorAssembler, scaler, lgbm))

      val paramGrid = new ParamGridBuilder()
        .addGrid(lgbm.learningRate, Array(0.05, 0.1))
        .build()

      // Score on the classifier's continuous output. Feeding the hard 0/1 "prediction"
      // column to the evaluator collapses the ROC curve to a single threshold, which
      // distorts both the reported AUC and the cross-validation model selection.
      val evaluator = new BinaryClassificationEvaluator()
        .setLabelCol(labelCol)
        .setRawPredictionCol("rawPrediction")
        .setMetricName("areaUnderROC")

      val cv = new CrossValidator()
        .setEstimator(pipeline)
        .setEvaluator(evaluator)
        .setEstimatorParamMaps(paramGrid)
        .setSeed(0L)

      // 80/20 train/test split with a fixed seed for reproducibility.
      val Array(training, test) = dataset.randomSplit(Array(0.8, 0.2), 0L)

      val lgbmModel = cv.fit(training)
      val results = lgbmModel.transform(test)

      val auc = evaluator.evaluate(results)
      println("----AUC--------")
      println(s"The model's auc: $auc")
    } finally {
      sparkSession.stop()
    }
  }

}

pom.xml 文件更新:各 dependency 的写法可以参考 mvnrepository.com 上对应版本页面中 "Maven" 选项卡给出的内容。
## mmlspark
https://mvnrepository.com/artifact/Azure/mmlspark/0.15

## lightgbmlib
https://mvnrepository.com/artifact/com.microsoft.ml.lightgbm/lightgbmlib/2.2.200
 <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.62</version>
        </dependency>
        <!-- Local jar referenced through system scope; the coordinates below are arbitrary labels. -->
        <dependency>
            <groupId>xgboost4j-spark-0.7</groupId>  <!-- custom value -->
            <artifactId>util</artifactId>    <!-- custom value -->
            <version>0.7</version> <!-- custom value -->
            <scope>system</scope> <!-- system scope: similar to provided; the jar must be supplied explicitly via systemPath and Maven will not look it up in a repository -->
            <systemPath>${basedir}/src/main/lib/xgboost4j-spark-0.7-jar-with-dependencies.jar</systemPath> <!-- jar located under src/main/lib relative to the project base directory -->
        </dependency>
        <!-- https://mvnrepository.com/artifact/Azure/mmlspark -->
        <dependency>
            <groupId>Azure</groupId>
            <artifactId>mmlspark</artifactId>
            <version>0.15</version>
            <scope>system</scope> <!-- system scope: similar to provided; the jar must be supplied explicitly via systemPath and Maven will not look it up in a repository -->
            <systemPath>${basedir}/src/main/lib/mmlspark-0.15.jar</systemPath> <!-- jar located under src/main/lib relative to the project base directory -->
        </dependency>
        <!-- https://mvnrepository.com/artifact/com.microsoft.ml.lightgbm/lightgbmlib -->
        <dependency>
            <groupId>com.microsoft.ml.lightgbm</groupId>  <!-- matches the group in the mvnrepository URL above -->
            <artifactId>lightgbmlib</artifactId>    <!-- matches the artifact in the mvnrepository URL above -->
            <version>2.2.200</version> <!-- same version as the lightgbmlib jar passed to spark-submit -->
        </dependency>
关注公众号【好便宜】( ID:haopianyi222 ),领红包啦~
阿里云,国内最大的云服务商,注册就送数千元优惠券:https://t.cn/AiQe5A0g
腾讯云,良心云,价格优惠: https://t.cn/AieHwwKl
搬瓦工,CN2 GIA 优质线路,搭梯子、海外建站推荐: https://t.cn/AieHwfX9
扫一扫关注公众号添加购物返利助手,领红包
Comments are closed.

推荐使用阿里云服务器

超多优惠券

服务器最低一折,一年不到100!

朕已阅去看看