From 4ca2c069073e8e7ff13deac47f5a7e6eb65ac914 Mon Sep 17 00:00:00 2001 From: fengyaojie Date: Fri, 29 Dec 2023 16:03:23 +0800 Subject: [PATCH 1/2] use tmp path --- .../org/apache/spark/executor/Executor.scala | 1047 +++++++++++++++++ 1 file changed, 1047 insertions(+) create mode 100644 trustzone/teeos/spark/core/src/main/scala/org/apache/spark/executor/Executor.scala diff --git a/trustzone/teeos/spark/core/src/main/scala/org/apache/spark/executor/Executor.scala b/trustzone/teeos/spark/core/src/main/scala/org/apache/spark/executor/Executor.scala new file mode 100644 index 000000000..771545d51 --- /dev/null +++ b/trustzone/teeos/spark/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -0,0 +1,1047 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.io.{File, NotSerializableException} +import java.lang.Thread.UncaughtExceptionHandler +import java.lang.management.ManagementFactory +import java.net.{URI, URL} +import java.nio.ByteBuffer +import java.util.{Locale, Properties} +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean +import javax.annotation.concurrent.GuardedBy +import javax.ws.rs.core.UriBuilder + +import scala.collection.JavaConverters._ +import scala.collection.immutable +import scala.collection.mutable.{ArrayBuffer, HashMap, Map, WrappedArray} +import scala.concurrent.duration._ +import scala.util.control.NonFatal + +import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.slf4j.MDC + +import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.internal.plugin.PluginContainer +import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} +import org.apache.spark.metrics.source.JVMCPUSource +import org.apache.spark.resource.ResourceInformation +import org.apache.spark.rpc.RpcTimeout +import org.apache.spark.scheduler._ +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} +import org.apache.spark.util._ +import org.apache.spark.util.io.ChunkedByteBuffer + +/** + * Spark executor, backed by a threadpool to run tasks. + * + * This can be used with Mesos, YARN, and the standalone scheduler. + * An internal RPC interface is used for communication with the driver, + * except in the case of Mesos fine-grained mode. 
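+ *
+ * In outline, tasks arrive through `launchTask` calls from an `ExecutorBackend`, run on the
+ * cached "Executor task launch worker" thread pool defined below, and report their status back
+ * to the driver through `ExecutorBackend.statusUpdate` and the periodic `reportHeartBeat` RPC.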
+ */ +private[spark] class Executor( + executorId: String, + executorHostname: String, + env: SparkEnv, + userClassPath: Seq[URL] = Nil, + isLocal: Boolean = false, + uncaughtExceptionHandler: UncaughtExceptionHandler = new SparkUncaughtExceptionHandler, + resources: immutable.Map[String, ResourceInformation]) + extends Logging { + + logInfo(s"Starting executor ID $executorId on host $executorHostname") + + private val executorShutdown = new AtomicBoolean(false) + ShutdownHookManager.addShutdownHook( + () => stop() + ) + // Application dependencies (added through SparkContext) that we've fetched so far on this node. + // Each map holds the master's timestamp for the version of that file or JAR we got. + private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() + private val currentJars: HashMap[String, Long] = new HashMap[String, Long]() + private val currentArchives: HashMap[String, Long] = new HashMap[String, Long]() + + private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) + + private val conf = env.conf + + // No ip or host:port - just hostname + Utils.checkHost(executorHostname) + // must not have port specified. + assert (0 == Utils.parseHostPort(executorHostname)._2) + + // Make sure the local hostname we report matches the cluster scheduler's name for this host + Utils.setCustomHostname(executorHostname) + + if (!isLocal) { + // Setup an uncaught exception handler for non-local mode. + // Make any thread terminations due to uncaught exceptions kill the entire + // executor process to avoid surprising stalls. + Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler) + } + + // Start worker thread pool + // Use UninterruptibleThread to run tasks so that we can allow running codes without being + // interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622, + // will hang forever if some methods are interrupted. + private val threadPool = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Executor task launch worker-%d") + .setThreadFactory((r: Runnable) => new UninterruptibleThread(r, "unused")) + .build() + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } + private val schemes = conf.get(EXECUTOR_METRICS_FILESYSTEM_SCHEMES) + .toLowerCase(Locale.ROOT).split(",").map(_.trim).filter(_.nonEmpty) + private val executorSource = new ExecutorSource(threadPool, executorId, schemes) + // Pool used for threads that supervise task killing / cancellation + private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper") + // For tasks which are in the process of being killed, this map holds the most recently created + // TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't + // a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding + // the integrity of the map's internal state). The purpose of this map is to prevent the creation + // of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to + // track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise + // create. The map key is a task id. 
+ private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]() + + val executorMetricsSource = + if (conf.get(METRICS_EXECUTORMETRICS_SOURCE_ENABLED)) { + Some(new ExecutorMetricsSource) + } else { + None + } + + if (!isLocal) { + env.blockManager.initialize(conf.getAppId) + env.metricsSystem.registerSource(executorSource) + env.metricsSystem.registerSource(new JVMCPUSource()) + executorMetricsSource.foreach(_.register(env.metricsSystem)) + env.metricsSystem.registerSource(env.blockManager.shuffleMetricsSource) + } else { + // This enable the registration of the executor source in local mode. + // The actual registration happens in SparkContext, + // it cannot be done here as the appId is not available yet + Executor.executorSourceLocalModeOnly = executorSource + } + + // Whether to load classes in user jars before those in Spark jars + private val userClassPathFirst = conf.get(EXECUTOR_USER_CLASS_PATH_FIRST) + + // Whether to monitor killed / interrupted tasks + private val taskReaperEnabled = conf.get(TASK_REAPER_ENABLED) + + private val killOnFatalErrorDepth = conf.get(EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH) + + // Create our ClassLoader + // do this after SparkEnv creation so can access the SecurityManager + private val urlClassLoader = createClassLoader() + private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + + // Set the classloader for serializer + env.serializer.setDefaultClassLoader(replClassLoader) + // SPARK-21928. SerializerManager's internal instance of Kryo might get used in netty threads + // for fetching remote cached RDD blocks, so need to make sure it uses the right classloader too. + env.serializerManager.setDefaultClassLoader(replClassLoader) + + // Max size of direct result. If task result is bigger than this, we use the block manager + // to send the result back. + private val maxDirectResultSize = Math.min( + conf.get(TASK_MAX_DIRECT_RESULT_SIZE), + RpcUtils.maxMessageSizeBytes(conf)) + + private val maxResultSize = conf.get(MAX_RESULT_SIZE) + + // Maintains the list of running tasks. + private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + + /** + * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES` + * times, it should kill itself. The default value is 60. For example, if max failures is 60 and + * heartbeat interval is 10s, then it will try to send heartbeats for up to 600s (10 minutes). + */ + private val HEARTBEAT_MAX_FAILURES = conf.get(EXECUTOR_HEARTBEAT_MAX_FAILURES) + + /** + * Whether to drop empty accumulators from heartbeats sent to the driver. Including the empty + * accumulators (that satisfy isZero) can make the size of the heartbeat message very large. + */ + private val HEARTBEAT_DROP_ZEROES = conf.get(EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES) + + /** + * Interval to send heartbeats, in milliseconds + */ + private val HEARTBEAT_INTERVAL_MS = conf.get(EXECUTOR_HEARTBEAT_INTERVAL) + + /** + * Interval to poll for executor metrics, in milliseconds + */ + private val METRICS_POLLING_INTERVAL_MS = conf.get(EXECUTOR_METRICS_POLLING_INTERVAL) + + private val pollOnHeartbeat = if (METRICS_POLLING_INTERVAL_MS > 0) false else true + + // Poller for the memory metrics. Visible for testing. + private[executor] val metricsPoller = new ExecutorMetricsPoller( + env.memoryManager, + METRICS_POLLING_INTERVAL_MS, + executorMetricsSource) + + // Executor for the heartbeat task. 
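+ // With the defaults described above this gives the executor a driver-unreachability budget of
+ // roughly HEARTBEAT_MAX_FAILURES * HEARTBEAT_INTERVAL_MS (e.g. 60 * 10s = 600s), after which
+ // reportHeartBeat() exits the JVM with ExecutorExitCode.HEARTBEAT_FAILURE.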
+ private val heartbeater = new Heartbeater( + () => Executor.this.reportHeartBeat(), + "executor-heartbeater", + HEARTBEAT_INTERVAL_MS) + + // must be initialized before running startDriverHeartbeat() + private val heartbeatReceiverRef = + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + + /** + * Count the failure times of heartbeat. It should only be accessed in the heartbeat thread. Each + * successful heartbeat will reset it to 0. + */ + private var heartbeatFailures = 0 + + /** + * Flag to prevent launching new tasks while decommissioned. There could be a race condition + * accessing this, but decommissioning is only intended to help not be a hard stop. + */ + private var decommissioned = false + + heartbeater.start() + + private val appStartTime = conf.getLong("spark.app.startTime", 0) + + // To allow users to distribute plugins and their required files + // specified by --jars, --files and --archives on application submission, those + // jars/files/archives should be downloaded and added to the class loader via + // updateDependencies. This should be done before plugin initialization below + // because executors search plugins from the class loader and initialize them. + private val Seq(initialUserJars, initialUserFiles, initialUserArchives) = + Seq("jar", "file", "archive").map { key => + conf.getOption(s"spark.app.initial.$key.urls").map { urls => + Map(urls.split(",").map(url => (url, appStartTime)): _*) + }.getOrElse(Map.empty) + } + updateDependencies(initialUserFiles, initialUserJars, initialUserArchives) + + // Plugins need to load using a class loader that includes the executor's user classpath. + // Plugins also needs to be initialized after the heartbeater started + // to avoid blocking to send heartbeat (see SPARK-32175). + private val plugins: Option[PluginContainer] = Utils.withContextClassLoader(replClassLoader) { + PluginContainer(env, resources.asJava) + } + + metricsPoller.start() + + private[executor] def numRunningTasks: Int = runningTasks.size() + + /** + * Mark an executor for decommissioning and avoid launching new tasks. + */ + private[spark] def decommission(): Unit = { + decommissioned = true + } + + def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { + val tr = new TaskRunner(context, taskDescription, plugins) + runningTasks.put(taskDescription.taskId, tr) + threadPool.execute(tr) + if (decommissioned) { + log.error(s"Launching a task while in decommissioned state.") + } + } + + def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = { + val taskRunner = runningTasks.get(taskId) + if (taskRunner != null) { + if (taskReaperEnabled) { + val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized { + val shouldCreateReaper = taskReaperForTask.get(taskId) match { + case None => true + case Some(existingReaper) => interruptThread && !existingReaper.interruptThread + } + if (shouldCreateReaper) { + val taskReaper = new TaskReaper( + taskRunner, interruptThread = interruptThread, reason = reason) + taskReaperForTask(taskId) = taskReaper + Some(taskReaper) + } else { + None + } + } + // Execute the TaskReaper from outside of the synchronized block. + maybeNewTaskReaper.foreach(taskReaperPool.execute) + } else { + taskRunner.kill(interruptThread = interruptThread, reason = reason) + } + } + } + + /** + * Function to kill the running tasks in an executor. + * This can be called by executor back-ends to kill the + * tasks instead of taking the JVM down. 
+ * @param interruptThread whether to interrupt the task thread + */ + def killAllTasks(interruptThread: Boolean, reason: String) : Unit = { + runningTasks.keys().asScala.foreach(t => + killTask(t, interruptThread = interruptThread, reason = reason)) + } + + def stop(): Unit = { + if (!executorShutdown.getAndSet(true)) { + env.metricsSystem.report() + try { + metricsPoller.stop() + } catch { + case NonFatal(e) => + logWarning("Unable to stop executor metrics poller", e) + } + try { + heartbeater.stop() + } catch { + case NonFatal(e) => + logWarning("Unable to stop heartbeater", e) + } + threadPool.shutdown() + + // Notify plugins that executor is shutting down so they can terminate cleanly + Utils.withContextClassLoader(replClassLoader) { + plugins.foreach(_.shutdown()) + } + if (!isLocal) { + env.stop() + } + } + } + + /** Returns the total amount of time this JVM process has spent in garbage collection. */ + private def computeTotalGcTime(): Long = { + ManagementFactory.getGarbageCollectorMXBeans.asScala.map(_.getCollectionTime).sum + } + + class TaskRunner( + execBackend: ExecutorBackend, + private val taskDescription: TaskDescription, + private val plugins: Option[PluginContainer]) + extends Runnable { + + val taskId = taskDescription.taskId + val taskName = taskDescription.name + val threadName = s"Executor task launch worker for $taskName" + val mdcProperties = taskDescription.properties.asScala + .filter(_._1.startsWith("mdc.")).toSeq + + /** If specified, this task has been killed and this option contains the reason. */ + @volatile private var reasonIfKilled: Option[String] = None + + @volatile private var threadId: Long = -1 + + def getThreadId: Long = threadId + + /** Whether this task has been finished. */ + @GuardedBy("TaskRunner.this") + private var finished = false + + def isFinished: Boolean = synchronized { finished } + + /** How much the JVM process has spent in GC when the task starts to run. */ + @volatile var startGCTime: Long = _ + + /** + * The task to run. This will be set in run() by deserializing the task binary coming + * from the driver. Once it is set, it will never be changed. + */ + @volatile var task: Task[Any] = _ + + def kill(interruptThread: Boolean, reason: String): Unit = { + logInfo(s"Executor is trying to kill $taskName, reason: $reason") + reasonIfKilled = Some(reason) + if (task != null) { + synchronized { + if (!finished) { + task.kill(interruptThread, reason) + } + } + } + } + + /** + * Set the finished flag to true and clear the current thread's interrupt status + */ + private def setTaskFinishedAndClearInterruptStatus(): Unit = synchronized { + this.finished = true + // SPARK-14234 - Reset the interrupted status of the thread to avoid the + // ClosedByInterruptException during execBackend.statusUpdate which causes + // Executor to crash + Thread.interrupted() + // Notify any waiting TaskReapers. Generally there will only be one reaper per task but there + // is a rare corner-case where one task can have two reapers in case cancel(interrupt=False) + // is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup: + notifyAll() + } + + /** + * Utility function to: + * 1. Report executor runtime and JVM gc time if possible + * 2. Collect accumulator updates + * 3. 
Set the finished flag to true and clear current thread's interrupt status + */ + private def collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs: Long) = { + // Report executor runtime and JVM gc time + Option(task).foreach(t => { + t.metrics.setExecutorRunTime(TimeUnit.NANOSECONDS.toMillis( + // SPARK-32898: it's possible that a task is killed when taskStartTimeNs has the initial + // value(=0) still. In this case, the executorRunTime should be considered as 0. + if (taskStartTimeNs > 0) System.nanoTime() - taskStartTimeNs else 0)) + t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) + }) + + // Collect latest accumulator values to report back to the driver + val accums: Seq[AccumulatorV2[_, _]] = + Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty) + val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + + setTaskFinishedAndClearInterruptStatus() + (accums, accUpdates) + } + + override def run(): Unit = { + setMDCForTask(taskName, mdcProperties) + threadId = Thread.currentThread.getId + Thread.currentThread.setName(threadName) + val threadMXBean = ManagementFactory.getThreadMXBean + val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) + val deserializeStartTimeNs = System.nanoTime() + val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L + Thread.currentThread.setContextClassLoader(replClassLoader) + val ser = env.closureSerializer.newInstance() + logInfo(s"Running $taskName") + execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) + var taskStartTimeNs: Long = 0 + var taskStartCpu: Long = 0 + startGCTime = computeTotalGcTime() + var taskStarted: Boolean = false + + try { + // Must be set before updateDependencies() is called, in case fetching dependencies + // requires access to properties contained within (e.g. for access control). + Executor.taskDeserializationProps.set(taskDescription.properties) + + updateDependencies( + taskDescription.addedFiles, taskDescription.addedJars, taskDescription.addedArchives) + task = ser.deserialize[Task[Any]]( + taskDescription.serializedTask, Thread.currentThread.getContextClassLoader) + task.localProperties = taskDescription.properties + task.setTaskMemoryManager(taskMemoryManager) + + // If this task has been killed before we deserialized it, let's quit now. Otherwise, + // continue executing the task. + val killReason = reasonIfKilled + if (killReason.isDefined) { + // Throw an exception rather than returning, because returning within a try{} block + // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl + // exception will be caught by the catch block, leading to an incorrect ExceptionFailure + // for the task. + throw new TaskKilledException(killReason.get) + } + + // The purpose of updating the epoch here is to invalidate executor map output status cache + // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be + // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so + // we don't need to make any special calls here. + if (!isLocal) { + logDebug(s"$taskName's epoch is ${task.epoch}") + env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch) + } + + metricsPoller.onTaskStart(taskId, task.stageId, task.stageAttemptId) + taskStarted = true + + // Run the actual task and measure its runtime. 
+ taskStartTimeNs = System.nanoTime() + taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L + var threwException = true + val value = Utils.tryWithSafeFinally { + val res = task.run( + taskAttemptId = taskId, + attemptNumber = taskDescription.attemptNumber, + metricsSystem = env.metricsSystem, + resources = taskDescription.resources, + plugins = plugins) + threwException = false + res + } { + val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId) + val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + + if (freedMemory > 0 && !threwException) { + val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, $taskName" + if (conf.get(UNSAFE_EXCEPTION_ON_MEMORY_LEAK)) { + throw new SparkException(errMsg) + } else { + logWarning(errMsg) + } + } + + if (releasedLocks.nonEmpty && !threwException) { + val errMsg = + s"${releasedLocks.size} block locks were not released by $taskName\n" + + releasedLocks.mkString("[", ", ", "]") + if (conf.get(STORAGE_EXCEPTION_PIN_LEAK)) { + throw new SparkException(errMsg) + } else { + logInfo(errMsg) + } + } + } + task.context.fetchFailed.foreach { fetchFailure => + // uh-oh. it appears the user code has caught the fetch-failure without throwing any + // other exceptions. Its *possible* this is what the user meant to do (though highly + // unlikely). So we will log an error and keep going. + logError(s"$taskName completed successfully though internally it encountered " + + s"unrecoverable fetch failures! Most likely this means user code is incorrectly " + + s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure) + } + val taskFinishNs = System.nanoTime() + val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { + threadMXBean.getCurrentThreadCpuTime + } else 0L + + // If the task has been killed, let's fail it. + task.context.killTaskIfInterrupted() + + val resultSer = env.serializer.newInstance() + val beforeSerializationNs = System.nanoTime() + val valueBytes = resultSer.serialize(value) + val afterSerializationNs = System.nanoTime() + + // Deserialization happens in two parts: first, we deserialize a Task object, which + // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. + task.metrics.setExecutorDeserializeTime(TimeUnit.NANOSECONDS.toMillis( + (taskStartTimeNs - deserializeStartTimeNs) + task.executorDeserializeTimeNs)) + task.metrics.setExecutorDeserializeCpuTime( + (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime) + // We need to subtract Task.run()'s deserialization time to avoid double-counting + task.metrics.setExecutorRunTime(TimeUnit.NANOSECONDS.toMillis( + (taskFinishNs - taskStartTimeNs) - task.executorDeserializeTimeNs)) + task.metrics.setExecutorCpuTime( + (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime) + task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) + task.metrics.setResultSerializationTime(TimeUnit.NANOSECONDS.toMillis( + afterSerializationNs - beforeSerializationNs)) + + // Expose task metrics using the Dropwizard metrics system. 
+ // Update task metrics counters + executorSource.METRIC_CPU_TIME.inc(task.metrics.executorCpuTime) + executorSource.METRIC_RUN_TIME.inc(task.metrics.executorRunTime) + executorSource.METRIC_JVM_GC_TIME.inc(task.metrics.jvmGCTime) + executorSource.METRIC_DESERIALIZE_TIME.inc(task.metrics.executorDeserializeTime) + executorSource.METRIC_DESERIALIZE_CPU_TIME.inc(task.metrics.executorDeserializeCpuTime) + executorSource.METRIC_RESULT_SERIALIZE_TIME.inc(task.metrics.resultSerializationTime) + executorSource.METRIC_SHUFFLE_FETCH_WAIT_TIME + .inc(task.metrics.shuffleReadMetrics.fetchWaitTime) + executorSource.METRIC_SHUFFLE_WRITE_TIME.inc(task.metrics.shuffleWriteMetrics.writeTime) + executorSource.METRIC_SHUFFLE_TOTAL_BYTES_READ + .inc(task.metrics.shuffleReadMetrics.totalBytesRead) + executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ + .inc(task.metrics.shuffleReadMetrics.remoteBytesRead) + executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK + .inc(task.metrics.shuffleReadMetrics.remoteBytesReadToDisk) + executorSource.METRIC_SHUFFLE_LOCAL_BYTES_READ + .inc(task.metrics.shuffleReadMetrics.localBytesRead) + executorSource.METRIC_SHUFFLE_RECORDS_READ + .inc(task.metrics.shuffleReadMetrics.recordsRead) + executorSource.METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED + .inc(task.metrics.shuffleReadMetrics.remoteBlocksFetched) + executorSource.METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED + .inc(task.metrics.shuffleReadMetrics.localBlocksFetched) + executorSource.METRIC_SHUFFLE_BYTES_WRITTEN + .inc(task.metrics.shuffleWriteMetrics.bytesWritten) + executorSource.METRIC_SHUFFLE_RECORDS_WRITTEN + .inc(task.metrics.shuffleWriteMetrics.recordsWritten) + executorSource.METRIC_INPUT_BYTES_READ + .inc(task.metrics.inputMetrics.bytesRead) + executorSource.METRIC_INPUT_RECORDS_READ + .inc(task.metrics.inputMetrics.recordsRead) + executorSource.METRIC_OUTPUT_BYTES_WRITTEN + .inc(task.metrics.outputMetrics.bytesWritten) + executorSource.METRIC_OUTPUT_RECORDS_WRITTEN + .inc(task.metrics.outputMetrics.recordsWritten) + executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize) + executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled) + executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled) + + // Note: accumulator updates must be collected after TaskMetrics is updated + val accumUpdates = task.collectAccumulatorUpdates() + val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId) + // TODO: do not serialize value twice + val directResult = new DirectTaskResult(valueBytes, accumUpdates, metricPeaks) + val serializedDirectResult = ser.serialize(directResult) + val resultSize = serializedDirectResult.limit() + + // directSend = sending directly back to the driver + val serializedResult: ByteBuffer = { + if (maxResultSize > 0 && resultSize > maxResultSize) { + logWarning(s"Finished $taskName. Result is larger than maxResultSize " + + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + + s"dropping it.") + ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) + } else if (resultSize > maxDirectResultSize) { + val blockId = TaskResultBlockId(taskId) + env.blockManager.putBytes( + blockId, + new ChunkedByteBuffer(serializedDirectResult.duplicate()), + StorageLevel.MEMORY_AND_DISK_SER) + logInfo(s"Finished $taskName. $resultSize bytes result sent via BlockManager)") + ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) + } else { + logInfo(s"Finished $taskName. 
$resultSize bytes result sent to driver") + serializedDirectResult + } + } + + executorSource.SUCCEEDED_TASKS.inc(1L) + setTaskFinishedAndClearInterruptStatus() + plugins.foreach(_.onTaskSucceeded()) + execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) + } catch { + case t: TaskKilledException => + logInfo(s"Executor killed $taskName, reason: ${t.reason}") + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs) + // Here and below, put task metric peaks in a WrappedArray to expose them as a Seq + // without requiring a copy. + val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId)) + val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq) + plugins.foreach(_.onTaskFailed(reason)) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason)) + + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => + val killReason = task.reasonIfKilled.getOrElse("unknown reason") + logInfo(s"Executor interrupted and killed $taskName, reason: $killReason") + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs) + val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId)) + val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq) + plugins.foreach(_.onTaskFailed(reason)) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason)) + + case t: Throwable if hasFetchFailure && !Executor.isFatalError(t, killOnFatalErrorDepth) => + val reason = task.context.fetchFailed.get.toTaskFailedReason + if (!t.isInstanceOf[FetchFailedException]) { + // there was a fetch failure in the task, but some user code wrapped that exception + // and threw something else. Regardless, we treat it as a fetch failure. + val fetchFailedCls = classOf[FetchFailedException].getName + logWarning(s"$taskName encountered a ${fetchFailedCls} and " + + s"failed, but the ${fetchFailedCls} was hidden by another " + + s"exception. Spark is handling this like a fetch failure and ignoring the " + + s"other exception: $t") + } + setTaskFinishedAndClearInterruptStatus() + plugins.foreach(_.onTaskFailed(reason)) + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + + case CausedBy(cDE: CommitDeniedException) => + val reason = cDE.toTaskCommitDeniedReason + setTaskFinishedAndClearInterruptStatus() + plugins.foreach(_.onTaskFailed(reason)) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason)) + + case t: Throwable if env.isStopped => + // Log the expected exception after executor.stop without stack traces + // see: SPARK-19147 + logError(s"Exception in $taskName: ${t.getMessage}") + + case t: Throwable => + // Attempt to exit cleanly by informing the driver of our failure. + // If anything goes wrong (or this was a fatal exception), we will delegate to + // the default uncaught exception handler, which will terminate the Executor. + logError(s"Exception in $taskName", t) + + // SPARK-20904: Do not report failure to driver if if happened during shut down. Because + // libraries may set up shutdown hooks that race with running tasks during shutdown, + // spurious failures may occur and can result in improper accounting in the driver (e.g. + // the task failure would not be ignored if the shutdown happened because of preemption, + // instead of an app issue). 
+ if (!ShutdownHookManager.inShutdown()) { + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs) + val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId)) + + val (taskFailureReason, serializedTaskFailureReason) = { + try { + val ef = new ExceptionFailure(t, accUpdates).withAccums(accums) + .withMetricPeaks(metricPeaks.toSeq) + (ef, ser.serialize(ef)) + } catch { + case _: NotSerializableException => + // t is not serializable so just send the stacktrace + val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums) + .withMetricPeaks(metricPeaks.toSeq) + (ef, ser.serialize(ef)) + } + } + setTaskFinishedAndClearInterruptStatus() + plugins.foreach(_.onTaskFailed(taskFailureReason)) + execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason) + } else { + logInfo("Not reporting error to driver during JVM shutdown.") + } + + // Don't forcibly exit unless the exception was inherently fatal, to avoid + // stopping other tasks unnecessarily. + if (Executor.isFatalError(t, killOnFatalErrorDepth)) { + uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) + } + } finally { + runningTasks.remove(taskId) + if (taskStarted) { + // This means the task was successfully deserialized, its stageId and stageAttemptId + // are known, and metricsPoller.onTaskStart was called. + metricsPoller.onTaskCompletion(taskId, task.stageId, task.stageAttemptId) + } + } + } + + private def hasFetchFailure: Boolean = { + task != null && task.context != null && task.context.fetchFailed.isDefined + } + } + + private def setMDCForTask(taskName: String, mdc: Seq[(String, String)]): Unit = { + // make sure we run the task with the user-specified mdc properties only + MDC.clear() + mdc.foreach { case (key, value) => MDC.put(key, value) } + // avoid overriding the takName by the user + MDC.put("mdc.taskName", taskName) + } + + /** + * Supervises the killing / cancellation of a task by sending the interrupted flag, optionally + * sending a Thread.interrupt(), and monitoring the task until it finishes. + * + * Spark's current task cancellation / task killing mechanism is "best effort" because some tasks + * may not be interruptible or may not respond to their "killed" flags being set. If a significant + * fraction of a cluster's task slots are occupied by tasks that have been marked as killed but + * remain running then this can lead to a situation where new jobs and tasks are starved of + * resources that are being used by these zombie tasks. + * + * The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie + * tasks. For backwards-compatibility / backportability this component is disabled by default + * and must be explicitly enabled by setting `spark.task.reaper.enabled=true`. + * + * A TaskReaper is created for a particular task when that task is killed / cancelled. Typically + * a task will have only one TaskReaper, but it's possible for a task to have up to two reapers + * in case kill is called twice with different values for the `interrupt` parameter. + * + * Once created, a TaskReaper will run until its supervised task has finished running. If the + * TaskReaper has not been configured to kill the JVM after a timeout (i.e. if + * `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely + * if the supervised task never exits. 
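+ *
+ * A minimal sketch of turning the reaper on via the configuration keys referenced in this
+ * class (the timeout value is only an example):
+ * {{{
+ *   conf.set("spark.task.reaper.enabled", "true")
+ *       .set("spark.task.reaper.killTimeout", "120s")
+ * }}}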
+ */ + private class TaskReaper( + taskRunner: TaskRunner, + val interruptThread: Boolean, + val reason: String) + extends Runnable { + + private[this] val taskId: Long = taskRunner.taskId + + private[this] val killPollingIntervalMs: Long = conf.get(TASK_REAPER_POLLING_INTERVAL) + + private[this] val killTimeoutNs: Long = { + TimeUnit.MILLISECONDS.toNanos(conf.get(TASK_REAPER_KILL_TIMEOUT)) + } + + private[this] val takeThreadDump: Boolean = conf.get(TASK_REAPER_THREAD_DUMP) + + override def run(): Unit = { + setMDCForTask(taskRunner.taskName, taskRunner.mdcProperties) + val startTimeNs = System.nanoTime() + def elapsedTimeNs = System.nanoTime() - startTimeNs + def timeoutExceeded(): Boolean = killTimeoutNs > 0 && elapsedTimeNs > killTimeoutNs + try { + // Only attempt to kill the task once. If interruptThread = false then a second kill + // attempt would be a no-op and if interruptThread = true then it may not be safe or + // effective to interrupt multiple times: + taskRunner.kill(interruptThread = interruptThread, reason = reason) + // Monitor the killed task until it exits. The synchronization logic here is complicated + // because we don't want to synchronize on the taskRunner while possibly taking a thread + // dump, but we also need to be careful to avoid races between checking whether the task + // has finished and wait()ing for it to finish. + var finished: Boolean = false + while (!finished && !timeoutExceeded()) { + taskRunner.synchronized { + // We need to synchronize on the TaskRunner while checking whether the task has + // finished in order to avoid a race where the task is marked as finished right after + // we check and before we call wait(). + if (taskRunner.isFinished) { + finished = true + } else { + taskRunner.wait(killPollingIntervalMs) + } + } + if (taskRunner.isFinished) { + finished = true + } else { + val elapsedTimeMs = TimeUnit.NANOSECONDS.toMillis(elapsedTimeNs) + logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms") + if (takeThreadDump) { + try { + Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread => + if (thread.threadName == taskRunner.threadName) { + logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}") + } + } + } catch { + case NonFatal(e) => + logWarning("Exception thrown while obtaining thread dump: ", e) + } + } + } + } + + if (!taskRunner.isFinished && timeoutExceeded()) { + val killTimeoutMs = TimeUnit.NANOSECONDS.toMillis(killTimeoutNs) + if (isLocal) { + logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " + + "not killing JVM because we are running in local mode.") + } else { + // In non-local-mode, the exception thrown here will bubble up to the uncaught exception + // handler and cause the executor JVM to exit. + throw new SparkException( + s"Killing executor JVM because killed task $taskId could not be stopped within " + + s"$killTimeoutMs ms.") + } + } + } finally { + // Clean up entries in the taskReaperForTask map. + taskReaperForTask.synchronized { + taskReaperForTask.get(taskId).foreach { taskReaperInMap => + if (taskReaperInMap eq this) { + taskReaperForTask.remove(taskId) + } else { + // This must have been a TaskReaper where interruptThread == false where a subsequent + // killTask() call for the same task had interruptThread == true and overwrote the + // map entry. 
+ } + } + } + } + } + } + + /** + * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes + * created by the interpreter to the search path + */ + private def createClassLoader(): MutableURLClassLoader = { + // Bootstrap the list of jars with the user class path. + val now = System.currentTimeMillis() + userClassPath.foreach { url => + currentJars(url.getPath().split("/").last) = now + } + + val currentLoader = Utils.getContextOrSparkClassLoader + + // For each of the jars in the jarSet, add them to the class loader. + // We assume each of the files has already been fetched. + val urls = userClassPath.toArray ++ currentJars.keySet.map { uri => + new File(uri.split("/").last).toURI.toURL + } + if (userClassPathFirst) { + new ChildFirstURLClassLoader(urls, currentLoader) + } else { + new MutableURLClassLoader(urls, currentLoader) + } + } + + /** + * If the REPL is in use, add another ClassLoader that will read + * new classes defined by the REPL as the user types code + */ + private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = { + val classUri = conf.get("spark.repl.class.uri", null) + if (classUri != null) { + logInfo("Using REPL class URI: " + classUri) + try { + val _userClassPathFirst: java.lang.Boolean = userClassPathFirst + val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") + .asInstanceOf[Class[_ <: ClassLoader]] + val constructor = klass.getConstructor(classOf[SparkConf], classOf[SparkEnv], + classOf[String], classOf[ClassLoader], classOf[Boolean]) + constructor.newInstance(conf, env, classUri, parent, _userClassPathFirst) + } catch { + case _: ClassNotFoundException => + logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") + System.exit(1) + null + } + } else { + parent + } + } + + /** + * Download any missing dependencies if we receive a new set of files and JARs from the + * SparkContext. Also adds any new JARs we fetched to the class loader. + */ + private def updateDependencies( + newFiles: Map[String, Long], + newJars: Map[String, Long], + newArchives: Map[String, Long]): Unit = { + lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + synchronized { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo(s"Fetching $name with timestamp $timestamp") + // Fetch file with useCache mode, close cache for local mode. 
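+ // Note: in this copy, plain files (and archives below) still land under
+ // SparkFiles.getRootDirectory(), while the jar branch further down fetches into a fixed
+ // "/tmp" directory and registers the resulting file:/tmp/<jar name> URLs with urlClassLoader.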
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf, + env.securityManager, hadoopConf, timestamp, useCache = !isLocal) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newArchives if currentArchives.getOrElse(name, -1L) < timestamp) { + logInfo(s"Fetching $name with timestamp $timestamp") + val sourceURI = new URI(name) + val uriToDownload = UriBuilder.fromUri(sourceURI).fragment(null).build() + val source = Utils.fetchFile(uriToDownload.toString, Utils.createTempDir(), conf, + env.securityManager, hadoopConf, timestamp, useCache = !isLocal, shouldUntar = false) + val dest = new File( + SparkFiles.getRootDirectory(), + if (sourceURI.getFragment != null) sourceURI.getFragment else source.getName) + logInfo( + s"Unpacking an archive $name from ${source.getAbsolutePath} to ${dest.getAbsolutePath}") + Utils.deleteRecursively(dest) + Utils.unpack(source, dest) + currentArchives(name) = timestamp + } + for ((name, timestamp) <- newJars) { + val localName = new URI(name).getPath.split("/").last + val currentTimeStamp = currentJars.get(name) + .orElse(currentJars.get(localName)) + .getOrElse(-1L) + if (currentTimeStamp < timestamp) { + logInfo(s"Fetching $name with timestamp $timestamp") + // Fetch file with useCache mode, close cache for local mode. + val tempFilePath = "/tmp" + Utils.fetchFile(name, new File(tempFilePath), conf, + env.securityManager, hadoopConf, timestamp, useCache = !isLocal) + currentJars(name) = timestamp + // Add it to our class loader + val url = new File(tempFilePath, localName).toURI.toURL + if (!urlClassLoader.getURLs().contains(url)) { + logInfo(s"Adding $url to class loader") + urlClassLoader.addURL(url) + } + } + } + } + } + + /** Reports heartbeat and metrics for active tasks to the driver. */ + private def reportHeartBeat(): Unit = { + // list of (task id, accumUpdates) to send back to the driver + val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulatorV2[_, _]])]() + val curGCTime = computeTotalGcTime() + + if (pollOnHeartbeat) { + metricsPoller.poll() + } + + val executorUpdates = metricsPoller.getExecutorUpdates() + + for (taskRunner <- runningTasks.values().asScala) { + if (taskRunner.task != null) { + taskRunner.task.metrics.mergeShuffleReadMetrics() + taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) + val accumulatorsToReport = + if (HEARTBEAT_DROP_ZEROES) { + taskRunner.task.metrics.accumulators().filterNot(_.isZero) + } else { + taskRunner.task.metrics.accumulators() + } + accumUpdates += ((taskRunner.taskId, accumulatorsToReport)) + } + } + + val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId, + executorUpdates) + try { + val response = heartbeatReceiverRef.askSync[HeartbeatResponse]( + message, new RpcTimeout(HEARTBEAT_INTERVAL_MS.millis, EXECUTOR_HEARTBEAT_INTERVAL.key)) + if (response.reregisterBlockManager) { + logInfo("Told to re-register on heartbeat") + env.blockManager.reregister() + } + heartbeatFailures = 0 + } catch { + case NonFatal(e) => + logWarning("Issue communicating with driver in heartbeater", e) + heartbeatFailures += 1 + if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) { + logError(s"Exit as unable to send heartbeats to driver " + + s"more than $HEARTBEAT_MAX_FAILURES times") + System.exit(ExecutorExitCode.HEARTBEAT_FAILURE) + } + } + } +} + +private[spark] object Executor { + // This is reserved for internal use by components that need to read task properties before a + // task is fully deserialized. 
When possible, the TaskContext.getLocalProperty call should be + // used instead. + val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties] + + // Used to store executorSource, for local mode only + var executorSourceLocalModeOnly: ExecutorSource = null + + /** + * Whether a `Throwable` thrown from a task is a fatal error. We will use this to decide whether + * to kill the executor. + * + * @param depthToCheck The max depth of the exception chain we should search for a fatal error. 0 + * means not checking any fatal error (in other words, return false), 1 means + * checking only the exception but not the cause, and so on. This is to avoid + * `StackOverflowError` when hitting a cycle in the exception chain. + */ + def isFatalError(t: Throwable, depthToCheck: Int): Boolean = { + if (depthToCheck <= 0) { + false + } else { + t match { + case _: SparkOutOfMemoryError => false + case e if Utils.isFatalError(e) => true + case e if e.getCause != null => isFatalError(e.getCause, depthToCheck - 1) + case _ => false + } + } + } +} -- Gitee From 503441ee05d0f5e32584333532e8f43608db9c86 Mon Sep 17 00:00:00 2001 From: fengyaojie Date: Fri, 26 Jan 2024 15:38:30 +0800 Subject: [PATCH 2/2] ivy tmp path --- .../org/apache/spark/deploy/SparkSubmit.scala | 1492 +++++++++++++++++ 1 file changed, 1492 insertions(+) create mode 100644 trustzone/teeos/spark/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala diff --git a/trustzone/teeos/spark/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/trustzone/teeos/spark/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala new file mode 100644 index 000000000..33b81c209 --- /dev/null +++ b/trustzone/teeos/spark/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -0,0 +1,1492 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import java.io._ +import java.lang.reflect.{InvocationTargetException, UndeclaredThrowableException} +import java.net.{URI, URL} +import java.security.PrivilegedExceptionAction +import java.text.ParseException +import java.util.{ServiceLoader, UUID} +import java.util.jar.JarInputStream +import javax.ws.rs.core.UriBuilder + +import scala.annotation.tailrec +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.util.{Properties, Try} + +import org.apache.commons.lang3.StringUtils +import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.ivy.Ivy +import org.apache.ivy.core.LogOptions +import org.apache.ivy.core.module.descriptor._ +import org.apache.ivy.core.module.id.{ArtifactId, ModuleId, ModuleRevisionId} +import org.apache.ivy.core.report.ResolveReport +import org.apache.ivy.core.resolve.ResolveOptions +import org.apache.ivy.core.retrieve.RetrieveOptions +import org.apache.ivy.core.settings.IvySettings +import org.apache.ivy.plugins.matcher.GlobPatternMatcher +import org.apache.ivy.plugins.repository.file.FileRepository +import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBiblioResolver} + +import org.apache.spark._ +import org.apache.spark.api.r.RUtils +import org.apache.spark.deploy.rest._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.UI._ +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.util._ + +/** + * Whether to submit, kill, or request the status of an application. + * The latter two operations are currently supported only for standalone and Mesos cluster modes. + */ +private[deploy] object SparkSubmitAction extends Enumeration { + type SparkSubmitAction = Value + val SUBMIT, KILL, REQUEST_STATUS, PRINT_VERSION = Value +} + +/** + * Main gateway of launching a Spark application. + * + * This program handles setting up the classpath with relevant Spark dependencies and provides + * a layer over the different cluster managers and deploy modes that Spark supports. + */ +private[spark] class SparkSubmit extends Logging { + + import DependencyUtils._ + import SparkSubmit._ + + def doSubmit(args: Array[String]): Unit = { + // Initialize logging if it hasn't been done yet. Keep track of whether logging needs to + // be reset before the application starts. + val uninitLog = initializeLogIfNecessary(true, silent = true) + + val appArgs = parseArguments(args) + if (appArgs.verbose) { + logInfo(appArgs.toString) + } + appArgs.action match { + case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog) + case SparkSubmitAction.KILL => kill(appArgs) + case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) + case SparkSubmitAction.PRINT_VERSION => printVersion() + } + } + + protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) + } + + /** + * Kill an existing submission. 
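+ *
+ * Typically reached from `spark-submit --kill <submissionId>` for standalone and Mesos cluster
+ * deployments: the REST client is used when the master supports it, otherwise the request is
+ * delegated to the SparkSubmitUtils.getSubmitOperations path in the body.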
+ */ + private def kill(args: SparkSubmitArguments): Unit = { + if (RestSubmissionClient.supportsRestClient(args.master)) { + new RestSubmissionClient(args.master) + .killSubmission(args.submissionToKill) + } else { + val sparkConf = args.toSparkConf() + sparkConf.set("spark.master", args.master) + SparkSubmitUtils + .getSubmitOperations(args.master) + .kill(args.submissionToKill, sparkConf) + } + } + + /** + * Request the status of an existing submission. + */ + private def requestStatus(args: SparkSubmitArguments): Unit = { + if (RestSubmissionClient.supportsRestClient(args.master)) { + new RestSubmissionClient(args.master) + .requestSubmissionStatus(args.submissionToRequestStatusFor) + } else { + val sparkConf = args.toSparkConf() + sparkConf.set("spark.master", args.master) + SparkSubmitUtils + .getSubmitOperations(args.master) + .printSubmissionStatus(args.submissionToRequestStatusFor, sparkConf) + } + } + + /** Print version information to the log. */ + private def printVersion(): Unit = { + logInfo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + logInfo("Using Scala %s, %s, %s".format( + Properties.versionString, Properties.javaVmName, Properties.javaVersion)) + logInfo(s"Branch $SPARK_BRANCH") + logInfo(s"Compiled by user $SPARK_BUILD_USER on $SPARK_BUILD_DATE") + logInfo(s"Revision $SPARK_REVISION") + logInfo(s"Url $SPARK_REPO_URL") + logInfo("Type --help for more information.") + } + + /** + * Submit the application using the provided parameters, ensuring to first wrap + * in a doAs when --proxy-user is specified. + */ + @tailrec + private def submit(args: SparkSubmitArguments, uninitLog: Boolean): Unit = { + + def doRunMain(): Unit = { + if (args.proxyUser != null) { + val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser, + UserGroupInformation.getCurrentUser()) + try { + proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + runMain(args, uninitLog) + } + }) + } catch { + case e: Exception => + // Hadoop's AuthorizationException suppresses the exception's stack trace, which + // makes the message printed to the output by the JVM not very helpful. Instead, + // detect exceptions with empty stack traces here, and treat them differently. + if (e.getStackTrace().length == 0) { + error(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + } else { + throw e + } + } + } else { + runMain(args, uninitLog) + } + } + + // In standalone cluster mode, there are two submission gateways: + // (1) The traditional RPC gateway using o.a.s.deploy.Client as a wrapper + // (2) The new REST-based gateway introduced in Spark 1.3 + // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over + // to use the legacy gateway if the master endpoint turns out to be not a REST server. + if (args.isStandaloneCluster && args.useRest) { + try { + logInfo("Running Spark using the REST application submission protocol.") + doRunMain() + } catch { + // Fail over to use the legacy submission gateway + case e: SubmitRestConnectionException => + logWarning(s"Master endpoint ${args.master} was not a REST server. " + + "Falling back to legacy submission gateway instead.") + args.useRest = false + submit(args, false) + } + // In all other modes, just run the main class as prepared + } else { + doRunMain() + } + } + + /** + * Prepare the environment for submitting an application. 
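+ * Callers destructure the result, e.g. (a sketch mirroring the @return description below):
+ * {{{
+ *   val (childArgs, childClasspath, sparkConf, childMainClass) = prepareSubmitEnvironment(args)
+ * }}}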
+ * + * @param args the parsed SparkSubmitArguments used for environment preparation. + * @param conf the Hadoop Configuration, this argument will only be set in unit test. + * @return a 4-tuple: + * (1) the arguments for the child process, + * (2) a list of classpath entries for the child, + * (3) a map of system properties, and + * (4) the main class for the child + * + * Exposed for testing. + */ + private[deploy] def prepareSubmitEnvironment( + args: SparkSubmitArguments, + conf: Option[HadoopConfiguration] = None) + : (Seq[String], Seq[String], SparkConf, String) = { + // Return values + val childArgs = new ArrayBuffer[String]() + val childClasspath = new ArrayBuffer[String]() + val sparkConf = args.toSparkConf() + var childMainClass = "" + + // Set the cluster manager + val clusterManager: Int = args.master match { + case "yarn" => YARN + case m if m.startsWith("spark") => STANDALONE + case m if m.startsWith("mesos") => MESOS + case m if m.startsWith("k8s") => KUBERNETES + case m if m.startsWith("local") => LOCAL + case _ => + error("Master must either be yarn or start with spark, mesos, k8s, or local") + -1 + } + + // Set the deploy mode; default is client mode + var deployMode: Int = args.deployMode match { + case "client" | null => CLIENT + case "cluster" => CLUSTER + case _ => + error("Deploy mode must be either client or cluster") + -1 + } + + if (clusterManager == YARN) { + // Make sure YARN is included in our build if we're trying to use it + if (!Utils.classIsLoadable(YARN_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { + error( + "Could not load YARN classes. " + + "This copy of Spark may not have been compiled with YARN support.") + } + } + + if (clusterManager == KUBERNETES) { + args.master = Utils.checkAndGetK8sMasterUrl(args.master) + // Make sure KUBERNETES is included in our build if we're trying to use it + if (!Utils.classIsLoadable(KUBERNETES_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { + error( + "Could not load KUBERNETES classes. " + + "This copy of Spark may not have been compiled with KUBERNETES support.") + } + } + + // Fail fast, the following modes are not supported or applicable + (clusterManager, deployMode) match { + case (STANDALONE, CLUSTER) if args.isPython => + error("Cluster deploy mode is currently not supported for python " + + "applications on standalone clusters.") + case (STANDALONE, CLUSTER) if args.isR => + error("Cluster deploy mode is currently not supported for R " + + "applications on standalone clusters.") + case (LOCAL, CLUSTER) => + error("Cluster deploy mode is not compatible with master \"local\"") + case (_, CLUSTER) if isShell(args.primaryResource) => + error("Cluster deploy mode is not applicable to Spark shells.") + case (_, CLUSTER) if isSqlShell(args.mainClass) => + error("Cluster deploy mode is not applicable to Spark SQL shell.") + case (_, CLUSTER) if isThriftServer(args.mainClass) => + error("Cluster deploy mode is not applicable to Spark Thrift server.") + case _ => + } + + // Update args.deployMode if it is null. It will be passed down as a Spark property later. 
+ (args.deployMode, deployMode) match { + case (null, CLIENT) => args.deployMode = "client" + case (null, CLUSTER) => args.deployMode = "cluster" + case _ => + } + val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER + val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER + val isKubernetesCluster = clusterManager == KUBERNETES && deployMode == CLUSTER + val isKubernetesClient = clusterManager == KUBERNETES && deployMode == CLIENT + val isKubernetesClusterModeDriver = isKubernetesClient && + sparkConf.getBoolean("spark.kubernetes.submitInDriver", false) + + if (!isMesosCluster && !isStandAloneCluster) { + // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files + // too for packages that include Python code + val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies( + args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath, + args.ivySettingsPath) + + if (!StringUtils.isBlank(resolvedMavenCoordinates)) { + // In K8s client mode, when in the driver, add resolved jars early as we might need + // them at the submit time for artifact downloading. + // For example we might use the dependencies for downloading + // files from a Hadoop Compatible fs e.g. S3. In this case the user might pass: + // --packages com.amazonaws:aws-java-sdk:1.7.4:org.apache.hadoop:hadoop-aws:2.7.6 + if (isKubernetesClusterModeDriver) { + val loader = getSubmitClassLoader(sparkConf) + for (jar <- resolvedMavenCoordinates.split(",")) { + addJarToClasspath(jar, loader) + } + } else if (isKubernetesCluster) { + // We need this in K8s cluster mode so that we can upload local deps + // via the k8s application, like in cluster mode driver + childClasspath ++= resolvedMavenCoordinates.split(",") + } else { + args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) + if (args.isPython || isInternal(args.primaryResource)) { + args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates) + } + } + } + + // install any R packages that may have been passed through --jars or --packages. + // Spark Packages may contain R source code inside the jar. + if (args.isR && !StringUtils.isBlank(args.jars)) { + RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose) + } + } + + // update spark config from args + args.toSparkConf(Option(sparkConf)) + val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) + val targetDir = Utils.createTempDir() + + // Kerberos is not supported in standalone mode, and keytab support is not yet available + // in Mesos cluster mode. + if (clusterManager != STANDALONE + && !isMesosCluster + && args.principal != null + && args.keytab != null) { + // If client mode, make sure the keytab is just a local path. + if (deployMode == CLIENT && Utils.isLocalUri(args.keytab)) { + args.keytab = new URI(args.keytab).getPath() + } + + if (!Utils.isLocalUri(args.keytab)) { + require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } + } + + // Resolve glob path for different resources. 
+ args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull + args.files = Option(args.files).map(resolveGlobPaths(_, hadoopConf)).orNull + args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull + args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull + + lazy val secMgr = new SecurityManager(sparkConf) + + // In client mode, download remote files. + var localPrimaryResource: String = null + var localJars: String = null + var localPyFiles: String = null + if (deployMode == CLIENT) { + localPrimaryResource = Option(args.primaryResource).map { + downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr) + }.orNull + localJars = Option(args.jars).map { + downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + }.orNull + localPyFiles = Option(args.pyFiles).map { + downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + }.orNull + + if (isKubernetesClusterModeDriver) { + // Replace with the downloaded local jar path to avoid propagating hadoop compatible uris. + // Executors will get the jars from the Spark file server. + // Explicitly download the related files here + args.jars = localJars + val filesLocalFiles = Option(args.files).map { + downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + }.orNull + val archiveLocalFiles = Option(args.archives).map { uris => + val resolvedUris = Utils.stringToSeq(uris).map(Utils.resolveURI) + val localArchives = downloadFileList( + resolvedUris.map( + UriBuilder.fromUri(_).fragment(null).build().toString).mkString(","), + targetDir, sparkConf, hadoopConf, secMgr) + + // SPARK-33748: this mimics the behaviour of Yarn cluster mode. If the driver is running + // in cluster mode, the archives should be available in the driver's current working + // directory too. + Utils.stringToSeq(localArchives).map(Utils.resolveURI).zip(resolvedUris).map { + case (localArchive, resolvedUri) => + val source = new File(localArchive.getPath) + val dest = new File( + ".", + if (resolvedUri.getFragment != null) resolvedUri.getFragment else source.getName) + logInfo( + s"Unpacking an archive $resolvedUri " + + s"from ${source.getAbsolutePath} to ${dest.getAbsolutePath}") + Utils.deleteRecursively(dest) + Utils.unpack(source, dest) + + // Keep the URIs of local files with the given fragments. + UriBuilder.fromUri( + localArchive).fragment(resolvedUri.getFragment).build().toString + }.mkString(",") + }.orNull + args.files = filesLocalFiles + args.archives = archiveLocalFiles + args.pyFiles = localPyFiles + } + } + + // When running in YARN, for some remote resources with scheme: + // 1. Hadoop FileSystem doesn't support them. + // 2. We explicitly bypass Hadoop FileSystem with "spark.yarn.dist.forceDownloadSchemes". + // We will download them to local disk prior to add to YARN's distributed cache. + // For yarn client mode, since we already download them with above code, so we only need to + // figure out the local path and replace the remote one. 
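+ // Example with an assumed setting: spark.yarn.dist.forceDownloadSchemes=http,https makes
+ // http(s) resources get fetched into the local temp dir below rather than handed straight
+ // to YARN's distributed cache; a value of "*" forces the download for every scheme.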
+ if (clusterManager == YARN) { + val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES) + + def shouldDownload(scheme: String): Boolean = { + forceDownloadSchemes.contains("*") || forceDownloadSchemes.contains(scheme) || + Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure + } + + def downloadResource(resource: String): String = { + val uri = Utils.resolveURI(resource) + uri.getScheme match { + case "local" | "file" => resource + case e if shouldDownload(e) => + val file = new File(targetDir, new Path(uri).getName) + if (file.exists()) { + file.toURI.toString + } else { + downloadFile(resource, targetDir, sparkConf, hadoopConf, secMgr) + } + case _ => uri.toString + } + } + + args.primaryResource = Option(args.primaryResource).map { downloadResource }.orNull + args.files = Option(args.files).map { files => + Utils.stringToSeq(files).map(downloadResource).mkString(",") + }.orNull + args.pyFiles = Option(args.pyFiles).map { pyFiles => + Utils.stringToSeq(pyFiles).map(downloadResource).mkString(",") + }.orNull + args.jars = Option(args.jars).map { jars => + Utils.stringToSeq(jars).map(downloadResource).mkString(",") + }.orNull + args.archives = Option(args.archives).map { archives => + Utils.stringToSeq(archives).map(downloadResource).mkString(",") + }.orNull + } + + // At this point, we have attempted to download all remote resources. + // Now we try to resolve the main class if our primary resource is a JAR. + if (args.mainClass == null && !args.isPython && !args.isR) { + try { + val uri = new URI( + Option(localPrimaryResource).getOrElse(args.primaryResource) + ) + val fs = FileSystem.get(uri, hadoopConf) + + Utils.tryWithResource(new JarInputStream(fs.open(new Path(uri)))) { jar => + args.mainClass = jar.getManifest.getMainAttributes.getValue("Main-Class") + } + } catch { + case e: Throwable => + error( + s"Failed to get main class in JAR with error '${e.getMessage}'. " + + " Please specify one with --class." + ) + } + + if (args.mainClass == null) { + // If we still can't figure out the main class at this point, blow up. + error("No main class set in JAR; please specify one with --class.") + } + } + + // If we're running a python app, set the main class to our specific python runner + if (args.isPython && deployMode == CLIENT) { + if (args.primaryResource == PYSPARK_SHELL) { + args.mainClass = "org.apache.spark.api.python.PythonGatewayServer" + } else { + // If a python file is provided, add it to the child arguments and list of files to deploy. + // Usage: PythonAppRunner
[app arguments] + args.mainClass = "org.apache.spark.deploy.PythonRunner" + args.childArgs = ArrayBuffer(localPrimaryResource, localPyFiles) ++ args.childArgs + } + } + + // Non-PySpark applications can need Python dependencies. + if (deployMode == CLIENT && clusterManager != YARN) { + // The YARN backend handles python files differently, so don't merge the lists. + args.files = mergeFileLists(args.files, args.pyFiles) + } + + if (localPyFiles != null) { + sparkConf.set(SUBMIT_PYTHON_FILES, localPyFiles.split(",").toSeq) + } + + // In YARN mode for an R app, add the SparkR package archive and the R package + // archive containing all of the built R libraries to archives so that they can + // be distributed with the job + if (args.isR && clusterManager == YARN) { + val sparkRPackagePath = RUtils.localSparkRPackagePath + if (sparkRPackagePath.isEmpty) { + error("SPARK_HOME does not exist for R application in YARN mode.") + } + val sparkRPackageFile = new File(sparkRPackagePath.get, SPARKR_PACKAGE_ARCHIVE) + if (!sparkRPackageFile.exists()) { + error(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") + } + val sparkRPackageURI = Utils.resolveURI(sparkRPackageFile.getAbsolutePath).toString + + // Distribute the SparkR package. + // Assigns a symbol link name "sparkr" to the shipped package. + args.archives = mergeFileLists(args.archives, sparkRPackageURI + "#sparkr") + + // Distribute the R package archive containing all the built R packages. + if (!RUtils.rPackages.isEmpty) { + val rPackageFile = + RPackageUtils.zipRLibraries(new File(RUtils.rPackages.get), R_PACKAGE_ARCHIVE) + if (!rPackageFile.exists()) { + error("Failed to zip all the built R packages.") + } + + val rPackageURI = Utils.resolveURI(rPackageFile.getAbsolutePath).toString + // Assigns a symbol link name "rpkg" to the shipped package. + args.archives = mergeFileLists(args.archives, rPackageURI + "#rpkg") + } + } + + // TODO: Support distributing R packages with standalone cluster + if (args.isR && clusterManager == STANDALONE && !RUtils.rPackages.isEmpty) { + error("Distributing R packages with standalone cluster is not supported.") + } + + // TODO: Support distributing R packages with mesos cluster + if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) { + error("Distributing R packages with mesos cluster is not supported.") + } + + // If we're running an R app, set the main class to our specific R runner + if (args.isR && deployMode == CLIENT) { + if (args.primaryResource == SPARKR_SHELL) { + args.mainClass = "org.apache.spark.api.r.RBackend" + } else { + // If an R file is provided, add it to the child arguments and list of files to deploy. + // Usage: RRunner
[app arguments] + args.mainClass = "org.apache.spark.deploy.RRunner" + args.childArgs = ArrayBuffer(localPrimaryResource) ++ args.childArgs + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + + if (isYarnCluster && args.isR) { + // In yarn-cluster mode for an R app, add primary resource to files + // that can be distributed with the job + args.files = mergeFileLists(args.files, args.primaryResource) + } + + // Special flag to avoid deprecation warnings at the client + sys.props("SPARK_SUBMIT") = "true" + + // A list of rules to map each argument to system properties or command-line options in + // each deploy mode; we iterate through these below + val options = List[OptionAssigner]( + + // All cluster managers + OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.master"), + OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + confKey = SUBMIT_DEPLOY_MODE.key), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.app.name"), + OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, confKey = "spark.jars.ivy"), + OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, + confKey = DRIVER_MEMORY.key), + OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + confKey = DRIVER_CLASS_PATH.key), + OptionAssigner(args.driverExtraJavaOptions, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + confKey = DRIVER_JAVA_OPTIONS.key), + OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + confKey = DRIVER_LIBRARY_PATH.key), + OptionAssigner(args.principal, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + confKey = PRINCIPAL.key), + OptionAssigner(args.keytab, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + confKey = KEYTAB.key), + OptionAssigner(args.pyFiles, ALL_CLUSTER_MGRS, CLUSTER, confKey = SUBMIT_PYTHON_FILES.key), + + // Propagate attributes for dependency resolution at the driver side + OptionAssigner(args.packages, STANDALONE | MESOS | KUBERNETES, + CLUSTER, confKey = "spark.jars.packages"), + OptionAssigner(args.repositories, STANDALONE | MESOS | KUBERNETES, + CLUSTER, confKey = "spark.jars.repositories"), + OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS | KUBERNETES, + CLUSTER, confKey = "spark.jars.ivy"), + OptionAssigner(args.packagesExclusions, STANDALONE | MESOS | KUBERNETES, + CLUSTER, confKey = "spark.jars.excludes"), + + // Yarn only + OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.queue"), + OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.pyFiles", + mergeFn = Some(mergeFileLists(_, _))), + OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars", + mergeFn = Some(mergeFileLists(_, _))), + OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files", + mergeFn = Some(mergeFileLists(_, _))), + OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives", + mergeFn = Some(mergeFileLists(_, _))), + + // Other options + OptionAssigner(args.numExecutors, YARN | KUBERNETES, ALL_DEPLOY_MODES, + confKey = EXECUTOR_INSTANCES.key), + OptionAssigner(args.executorCores, STANDALONE | YARN | KUBERNETES, ALL_DEPLOY_MODES, + confKey = EXECUTOR_CORES.key), + OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN | KUBERNETES, ALL_DEPLOY_MODES, + confKey = EXECUTOR_MEMORY.key), + OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + confKey = CORES_MAX.key), + OptionAssigner(args.files, LOCAL | 
STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + confKey = FILES.key), + OptionAssigner(args.archives, LOCAL | STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + confKey = ARCHIVES.key), + OptionAssigner(args.jars, LOCAL, CLIENT, confKey = JARS.key), + OptionAssigner(args.jars, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + confKey = JARS.key), + OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, + confKey = DRIVER_MEMORY.key), + OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, + confKey = DRIVER_CORES.key), + OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, + confKey = DRIVER_SUPERVISE.key), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, confKey = "spark.jars.ivy"), + + // An internal option used only for spark-shell to add user jars to repl's classloader, + // previously it uses "spark.jars" or "spark.yarn.dist.jars" which now may be pointed to + // remote jars, so adding a new option to only specify local jars for spark-shell internally. + OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, confKey = "spark.repl.local.jars") + ) + + // In client mode, launch the application main class directly + // In addition, add the main application jar and any added jars (if any) to the classpath + if (deployMode == CLIENT) { + childMainClass = args.mainClass + if (localPrimaryResource != null && isUserJar(localPrimaryResource)) { + childClasspath += localPrimaryResource + } + if (localJars != null) { childClasspath ++= localJars.split(",") } + } + // Add the main application jar and any added jars to classpath in case YARN client + // requires these jars. + // This assumes both primaryResource and user jars are local jars, or already downloaded + // to local by configuring "spark.yarn.dist.forceDownloadSchemes", otherwise it will not be + // added to the classpath of YARN client. + if (isYarnCluster) { + if (isUserJar(args.primaryResource)) { + childClasspath += args.primaryResource + } + if (args.jars != null) { childClasspath ++= args.jars.split(",") } + } + + if (deployMode == CLIENT) { + if (args.childArgs != null) { childArgs ++= args.childArgs } + } + + // Map all arguments to command-line options or system properties for our chosen mode + for (opt <- options) { + if (opt.value != null && + (deployMode & opt.deployMode) != 0 && + (clusterManager & opt.clusterManager) != 0) { + if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) } + if (opt.confKey != null) { + if (opt.mergeFn.isDefined && sparkConf.contains(opt.confKey)) { + sparkConf.set(opt.confKey, opt.mergeFn.get.apply(sparkConf.get(opt.confKey), opt.value)) + } else { + sparkConf.set(opt.confKey, opt.value) + } + } + } + } + + // In case of shells, spark.ui.showConsoleProgress can be true by default or by user. + if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) { + sparkConf.set(UI_SHOW_CONSOLE_PROGRESS, true) + } + + // Add the application jar automatically so the user doesn't have to call sc.addJar + // For YARN cluster mode, the jar is already distributed on each node as "app.jar" + // For python and R files, the primary resource is already distributed as a regular file + if (!isYarnCluster && !args.isPython && !args.isR) { + var jars = sparkConf.get(JARS) + if (isUserJar(args.primaryResource)) { + jars = jars ++ Seq(args.primaryResource) + } + sparkConf.set(JARS, jars) + } + + // In standalone cluster mode, use the REST client to submit the application (Spark 1.3+). 
+ // All Spark parameters are expected to be passed to the client through system properties. + if (args.isStandaloneCluster) { + if (args.useRest) { + childMainClass = REST_CLUSTER_SUBMIT_CLASS + childArgs += (args.primaryResource, args.mainClass) + } else { + // In legacy standalone cluster mode, use Client as a wrapper around the user class + childMainClass = STANDALONE_CLUSTER_SUBMIT_CLASS + if (args.supervise) { childArgs += "--supervise" } + Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) } + Option(args.driverCores).foreach { c => childArgs += ("--cores", c) } + childArgs += "launch" + childArgs += (args.master, args.primaryResource, args.mainClass) + } + if (args.childArgs != null) { + childArgs ++= args.childArgs + } + } + + // Let YARN know it's a pyspark app, so it distributes needed libraries. + if (clusterManager == YARN) { + if (args.isPython) { + sparkConf.set("spark.yarn.isPython", "true") + } + } + + if ((clusterManager == MESOS || clusterManager == KUBERNETES) + && UserGroupInformation.isSecurityEnabled) { + setRMPrincipal(sparkConf) + } + + // In yarn-cluster mode, use yarn.Client as a wrapper around the user class + if (isYarnCluster) { + childMainClass = YARN_CLUSTER_SUBMIT_CLASS + if (args.isPython) { + childArgs += ("--primary-py-file", args.primaryResource) + childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") + } else if (args.isR) { + val mainFile = new Path(args.primaryResource).getName + childArgs += ("--primary-r-file", mainFile) + childArgs += ("--class", "org.apache.spark.deploy.RRunner") + } else { + if (args.primaryResource != SparkLauncher.NO_RESOURCE) { + childArgs += ("--jar", args.primaryResource) + } + childArgs += ("--class", args.mainClass) + } + if (args.childArgs != null) { + args.childArgs.foreach { arg => childArgs += ("--arg", arg) } + } + } + + if (isMesosCluster) { + assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") + childMainClass = REST_CLUSTER_SUBMIT_CLASS + if (args.isPython) { + // Second argument is main class + childArgs += (args.primaryResource, "") + if (args.pyFiles != null) { + sparkConf.set(SUBMIT_PYTHON_FILES, args.pyFiles.split(",").toSeq) + } + } else if (args.isR) { + // Second argument is main class + childArgs += (args.primaryResource, "") + } else { + childArgs += (args.primaryResource, args.mainClass) + } + if (args.childArgs != null) { + childArgs ++= args.childArgs + } + } + + if (isKubernetesCluster) { + childMainClass = KUBERNETES_CLUSTER_SUBMIT_CLASS + if (args.primaryResource != SparkLauncher.NO_RESOURCE) { + if (args.isPython) { + childArgs ++= Array("--primary-py-file", args.primaryResource) + childArgs ++= Array("--main-class", "org.apache.spark.deploy.PythonRunner") + } else if (args.isR) { + childArgs ++= Array("--primary-r-file", args.primaryResource) + childArgs ++= Array("--main-class", "org.apache.spark.deploy.RRunner") + } + else { + childArgs ++= Array("--primary-java-resource", args.primaryResource) + childArgs ++= Array("--main-class", args.mainClass) + } + } else { + childArgs ++= Array("--main-class", args.mainClass) + } + if (args.childArgs != null) { + args.childArgs.foreach { arg => + childArgs += ("--arg", arg) + } + } + // Pass the proxyUser to the k8s app so it is possible to add it to the driver args + if (args.proxyUser != null) { + childArgs += ("--proxy-user", args.proxyUser) + } + } + + // Load any properties specified through --conf and the default properties file + for ((k, v) <- args.sparkProperties) { + 
sparkConf.setIfMissing(k, v) + } + + // Ignore invalid spark.driver.host in cluster modes. + if (deployMode == CLUSTER) { + sparkConf.remove(DRIVER_HOST_ADDRESS) + } + + // Resolve paths in certain spark properties + val pathConfigs = Seq( + JARS.key, + FILES.key, + ARCHIVES.key, + "spark.yarn.dist.files", + "spark.yarn.dist.archives", + "spark.yarn.dist.jars") + pathConfigs.foreach { config => + // Replace old URIs with resolved URIs, if they exist + sparkConf.getOption(config).foreach { oldValue => + sparkConf.set(config, Utils.resolveURIs(oldValue)) + } + } + + // Resolve and format python file paths properly before adding them to the PYTHONPATH. + // The resolving part is redundant in the case of --py-files, but necessary if the user + // explicitly sets `spark.submit.pyFiles` in his/her default properties file. + val pyFiles = sparkConf.get(SUBMIT_PYTHON_FILES) + val resolvedPyFiles = Utils.resolveURIs(pyFiles.mkString(",")) + val formattedPyFiles = if (deployMode != CLUSTER) { + PythonRunner.formatPaths(resolvedPyFiles).mkString(",") + } else { + // Ignoring formatting python path in yarn and mesos cluster mode, these two modes + // support dealing with remote python files, they could distribute and add python files + // locally. + resolvedPyFiles + } + sparkConf.set(SUBMIT_PYTHON_FILES, formattedPyFiles.split(",").toSeq) + + (childArgs.toSeq, childClasspath.toSeq, sparkConf, childMainClass) + } + + // [SPARK-20328]. HadoopRDD calls into a Hadoop library that fetches delegation tokens with + // renewer set to the YARN ResourceManager. Since YARN isn't configured in Mesos or Kubernetes + // mode, we must trick it into thinking we're YARN. + private def setRMPrincipal(sparkConf: SparkConf): Unit = { + val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName + val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" + logInfo(s"Setting ${key} to ${shortUserName}") + sparkConf.set(key, shortUserName) + } + + private def getSubmitClassLoader(sparkConf: SparkConf): MutableURLClassLoader = { + val loader = + if (sparkConf.get(DRIVER_USER_CLASS_PATH_FIRST)) { + new ChildFirstURLClassLoader(new Array[URL](0), + Thread.currentThread.getContextClassLoader) + } else { + new MutableURLClassLoader(new Array[URL](0), + Thread.currentThread.getContextClassLoader) + } + Thread.currentThread.setContextClassLoader(loader) + loader + } + + /** + * Run the main method of the child class using the submit arguments. + * + * This runs in two steps. First, we prepare the launch environment by setting up + * the appropriate classpath, system properties, and application arguments for + * running the child main class based on the cluster manager and the deploy mode. + * Second, we use this launch environment to invoke the main method of the child + * main class. + * + * Note that this main class will not be the one provided by the user if we're + * running cluster deploy mode or python applications. + */ + private def runMain(args: SparkSubmitArguments, uninitLog: Boolean): Unit = { + val (childArgs, childClasspath, sparkConf, childMainClass) = prepareSubmitEnvironment(args) + // Let the main class re-initialize the logging system once it starts. 
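+ // (uninitLog is true when SparkSubmit had to initialize logging itself for argument
+ // parsing; resetting it here lets the user application apply its own logging config.)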
+ if (uninitLog) { + Logging.uninitialize() + } + + if (args.verbose) { + logInfo(s"Main class:\n$childMainClass") + logInfo(s"Arguments:\n${childArgs.mkString("\n")}") + // sysProps may contain sensitive information, so redact before printing + logInfo(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") + logInfo(s"Classpath elements:\n${childClasspath.mkString("\n")}") + logInfo("\n") + } + val loader = getSubmitClassLoader(sparkConf) + for (jar <- childClasspath) { + addJarToClasspath(jar, loader) + } + + var mainClass: Class[_] = null + + try { + mainClass = Utils.classForName(childMainClass) + } catch { + case e: ClassNotFoundException => + logError(s"Failed to load class $childMainClass.") + if (childMainClass.contains("thriftserver")) { + logInfo(s"Failed to load main class $childMainClass.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") + } + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) + case e: NoClassDefFoundError => + logError(s"Failed to load $childMainClass: ${e.getMessage()}") + if (e.getMessage.contains("org/apache/hadoop/hive")) { + logInfo(s"Failed to load hive class.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") + } + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) + } + + val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { + mainClass.getConstructor().newInstance().asInstanceOf[SparkApplication] + } else { + new JavaMainApplication(mainClass) + } + + @tailrec + def findCause(t: Throwable): Throwable = t match { + case e: UndeclaredThrowableException => + if (e.getCause() != null) findCause(e.getCause()) else e + case e: InvocationTargetException => + if (e.getCause() != null) findCause(e.getCause()) else e + case e: Throwable => + e + } + + try { + app.start(childArgs.toArray, sparkConf) + } catch { + case t: Throwable => + throw findCause(t) + } + } + + /** Throw a SparkException with the given error message. */ + private def error(msg: String): Unit = throw new SparkException(msg) + +} + + +/** + * This entry point is used by the launcher library to start in-process Spark applications. + */ +private[spark] object InProcessSparkSubmit { + + def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() + submit.doSubmit(args) + } + +} + +object SparkSubmit extends CommandLineUtils with Logging { + + // Cluster managers + private val YARN = 1 + private val STANDALONE = 2 + private val MESOS = 4 + private val LOCAL = 8 + private val KUBERNETES = 16 + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES + + // Deploy modes + private val CLIENT = 1 + private val CLUSTER = 2 + private val ALL_DEPLOY_MODES = CLIENT | CLUSTER + + // Special primary resource names that represent shells rather than application jars. + private val SPARK_SHELL = "spark-shell" + private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" + private val R_PACKAGE_ARCHIVE = "rpkg.zip" + + private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + + // Following constants are visible for testing. 
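+ // They are the childMainClass values chosen for the cluster deploy paths above, e.g. a
+ // yarn-cluster submission is dispatched through YarnClusterApplication.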
+ private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.yarn.YarnClusterApplication" + private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() + private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() + private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" + + override def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() { + self => + + override protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) { + override protected def logInfo(msg: => String): Unit = self.logInfo(msg) + + override protected def logWarning(msg: => String): Unit = self.logWarning(msg) + + override protected def logError(msg: => String): Unit = self.logError(msg) + } + } + + override protected def logInfo(msg: => String): Unit = printMessage(msg) + + override protected def logWarning(msg: => String): Unit = printMessage(s"Warning: $msg") + + override protected def logError(msg: => String): Unit = printMessage(s"Error: $msg") + + override def doSubmit(args: Array[String]): Unit = { + try { + super.doSubmit(args) + } catch { + case e: SparkUserAppException => + exitFn(e.exitCode) + } + } + + } + + submit.doSubmit(args) + } + + /** + * Return whether the given primary resource represents a user jar. + */ + private[deploy] def isUserJar(res: String): Boolean = { + !isShell(res) && !isPython(res) && !isInternal(res) && !isR(res) + } + + /** + * Return whether the given primary resource represents a shell. + */ + private[deploy] def isShell(res: String): Boolean = { + (res == SPARK_SHELL || res == PYSPARK_SHELL || res == SPARKR_SHELL) + } + + /** + * Return whether the given main class represents a sql shell. + */ + private[deploy] def isSqlShell(mainClass: String): Boolean = { + mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" + } + + /** + * Return whether the given main class represents a thrift server. + */ + private def isThriftServer(mainClass: String): Boolean = { + mainClass == "org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" + } + + /** + * Return whether the given primary resource requires running python. + */ + private[deploy] def isPython(res: String): Boolean = { + res != null && res.endsWith(".py") || res == PYSPARK_SHELL + } + + /** + * Return whether the given primary resource requires running R. + */ + private[deploy] def isR(res: String): Boolean = { + res != null && (res.endsWith(".R") || res.endsWith(".r")) || res == SPARKR_SHELL + } + + private[deploy] def isInternal(res: String): Boolean = { + res == SparkLauncher.NO_RESOURCE + } + +} + +/** Provides utility functions to be used inside SparkSubmit. */ +private[spark] object SparkSubmitUtils { + + // Exposed for testing + var printStream = SparkSubmit.printStream + + // Exposed for testing. + // These components are used to make the default exclusion rules for Spark dependencies. + // We need to specify each component explicitly, otherwise we miss + // spark-streaming utility components. 
Underscore is there to differentiate between + // spark-streaming_2.1x and spark-streaming-kafka-0-10-assembly_2.1x + val IVY_DEFAULT_EXCLUDES = Seq("catalyst_", "core_", "graphx_", "kvstore_", "launcher_", "mllib_", + "mllib-local_", "network-common_", "network-shuffle_", "repl_", "sketch_", "sql_", "streaming_", + "tags_", "unsafe_") + + /** + * Represents a Maven Coordinate + * @param groupId the groupId of the coordinate + * @param artifactId the artifactId of the coordinate + * @param version the version of the coordinate + */ + private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) { + override def toString: String = s"$groupId:$artifactId:$version" + } + + /** + * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided + * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. + * @param coordinates Comma-delimited string of maven coordinates + * @return Sequence of Maven coordinates + */ + def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = { + coordinates.split(",").map { p => + val splits = p.replace("/", ":").split(":") + require(splits.length == 3, s"Provided Maven Coordinates must be in the form " + + s"'groupId:artifactId:version'. The coordinate provided is: $p") + require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " + + s"be whitespace. The groupId provided is: ${splits(0)}") + require(splits(1) != null && splits(1).trim.nonEmpty, s"The artifactId cannot be null or " + + s"be whitespace. The artifactId provided is: ${splits(1)}") + require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " + + s"be whitespace. The version provided is: ${splits(2)}") + new MavenCoordinate(splits(0), splits(1), splits(2)) + } + } + + /** Path of the local Maven cache. */ + private[spark] def m2Path: File = { + if (Utils.isTesting) { + // test builds delete the maven cache, and this can cause flakiness + new File("dummy", ".m2" + File.separator + "repository") + } else { + new File(System.getProperty("user.home"), ".m2" + File.separator + "repository") + } + } + + /** + * Extracts maven coordinates from a comma-delimited string + * @param defaultIvyUserDir The default user path for Ivy + * @return A ChainResolver used by Ivy to search for and resolve dependencies. 
+ */ + def createRepoResolvers(defaultIvyUserDir: File): ChainResolver = { + // We need a chain resolver if we want to check multiple repositories + val cr = new ChainResolver + cr.setName("spark-list") + + val localM2 = new IBiblioResolver + localM2.setM2compatible(true) + localM2.setRoot(m2Path.toURI.toString) + localM2.setUsepoms(true) + localM2.setName("local-m2-cache") + cr.add(localM2) + + val localIvy = new FileSystemResolver + val localIvyRoot = new File(defaultIvyUserDir, "local") + localIvy.setLocal(true) + localIvy.setRepository(new FileRepository(localIvyRoot)) + val ivyPattern = Seq(localIvyRoot.getAbsolutePath, "[organisation]", "[module]", "[revision]", + "ivys", "ivy.xml").mkString(File.separator) + localIvy.addIvyPattern(ivyPattern) + val artifactPattern = Seq(localIvyRoot.getAbsolutePath, "[organisation]", "[module]", + "[revision]", "[type]s", "[artifact](-[classifier]).[ext]").mkString(File.separator) + localIvy.addArtifactPattern(artifactPattern) + localIvy.setName("local-ivy-cache") + cr.add(localIvy) + + // the biblio resolver resolves POM declared dependencies + val br: IBiblioResolver = new IBiblioResolver + br.setM2compatible(true) + br.setUsepoms(true) + val defaultInternalRepo : Option[String] = sys.env.get("DEFAULT_ARTIFACT_REPOSITORY") + br.setRoot(defaultInternalRepo.getOrElse("https://repo1.maven.org/maven2/")) + br.setName("central") + cr.add(br) + + val sp: IBiblioResolver = new IBiblioResolver + sp.setM2compatible(true) + sp.setUsepoms(true) + sp.setRoot(sys.env.getOrElse( + "DEFAULT_ARTIFACT_REPOSITORY", "https://dl.bintray.com/spark-packages/maven")) + sp.setName("spark-packages") + cr.add(sp) + cr + } + + /** + * Output a comma-delimited list of paths for the downloaded jars to be added to the classpath + * (will append to jars in SparkSubmit). + * @param artifacts Sequence of dependencies that were resolved and retrieved + * @param cacheDirectory directory where jars are cached + * @return a comma-delimited list of paths for the dependencies + */ + def resolveDependencyPaths( + artifacts: Array[AnyRef], + cacheDirectory: File): String = { + artifacts.map { artifactInfo => + val artifact = artifactInfo.asInstanceOf[Artifact].getModuleRevisionId + val extraAttrs = artifactInfo.asInstanceOf[Artifact].getExtraAttributes + val classifier = if (extraAttrs.containsKey("classifier")) { + "-" + extraAttrs.get("classifier") + } else { + "" + } + cacheDirectory.getAbsolutePath + File.separator + + s"${artifact.getOrganisation}_${artifact.getName}-${artifact.getRevision}$classifier.jar" + }.mkString(",") + } + + /** Adds the given maven coordinates to Ivy's module descriptor. 
*/ + def addDependenciesToIvy( + md: DefaultModuleDescriptor, + artifacts: Seq[MavenCoordinate], + ivyConfName: String): Unit = { + artifacts.foreach { mvn => + val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) + val dd = new DefaultDependencyDescriptor(ri, false, false) + dd.addDependencyConfiguration(ivyConfName, ivyConfName + "(runtime)") + // scalastyle:off println + printStream.println(s"${dd.getDependencyId} added as a dependency") + // scalastyle:on println + md.addDependency(dd) + } + } + + /** Add exclusion rules for dependencies already included in the spark-assembly */ + def addExclusionRules( + ivySettings: IvySettings, + ivyConfName: String, + md: DefaultModuleDescriptor): Unit = { + // Add scala exclusion rule + md.addExcludeRule(createExclusion("*:scala-library:*", ivySettings, ivyConfName)) + + IVY_DEFAULT_EXCLUDES.foreach { comp => + md.addExcludeRule(createExclusion(s"org.apache.spark:spark-$comp*:*", ivySettings, + ivyConfName)) + } + } + + /** + * Build Ivy Settings using options with default resolvers + * @param remoteRepos Comma-delimited string of remote repositories other than maven central + * @param ivyPath The path to the local ivy repository + * @return An IvySettings object + */ + def buildIvySettings(remoteRepos: Option[String], ivyPath: Option[String]): IvySettings = { + val ivySettings: IvySettings = new IvySettings + processIvyPathArg(ivySettings, ivyPath) + + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + ivySettings.setDefaultIvyUserDir(new File("/tmp")) + // create the dependency resolvers + val repoResolver = createRepoResolvers(ivySettings.getDefaultIvyUserDir) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + processRemoteRepoArg(ivySettings, remoteRepos) + ivySettings + } + + /** + * Load Ivy settings from a given filename, using supplied resolvers + * @param settingsFile Path to Ivy settings file + * @param remoteRepos Comma-delimited string of remote repositories other than maven central + * @param ivyPath The path to the local ivy repository + * @return An IvySettings object + */ + def loadIvySettings( + settingsFile: String, + remoteRepos: Option[String], + ivyPath: Option[String]): IvySettings = { + val file = new File(settingsFile) + require(file.exists(), s"Ivy settings file $file does not exist") + require(file.isFile(), s"Ivy settings file $file is not a normal file") + val ivySettings: IvySettings = new IvySettings + try { + ivySettings.load(file) + } catch { + case e @ (_: IOException | _: ParseException) => + throw new SparkException(s"Failed when loading Ivy settings from $settingsFile", e) + } + processIvyPathArg(ivySettings, ivyPath) + processRemoteRepoArg(ivySettings, remoteRepos) + ivySettings + } + + /* Set ivy settings for location of cache, if option is supplied */ + private def processIvyPathArg(ivySettings: IvySettings, ivyPath: Option[String]): Unit = { + ivyPath.filterNot(_.trim.isEmpty).foreach { alternateIvyDir => + ivySettings.setDefaultIvyUserDir(new File(alternateIvyDir)) + ivySettings.setDefaultCache(new File(alternateIvyDir, "cache")) + } + } + + /* Add any optional additional remote repositories */ + private def processRemoteRepoArg(ivySettings: IvySettings, remoteRepos: Option[String]): Unit = { + remoteRepos.filterNot(_.trim.isEmpty).map(_.split(",")).foreach { repositoryList => + val cr = new ChainResolver + cr.setName("user-list") + + // add current default resolver, if any + 
Option(ivySettings.getDefaultResolver).foreach(cr.add) + + // add additional repositories, last resolution in chain takes precedence + repositoryList.zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + // scalastyle:off println + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + // scalastyle:on println + } + + ivySettings.addResolver(cr) + ivySettings.setDefaultResolver(cr.getName) + } + } + + /** A nice function to use in tests as well. Values are dummy strings. */ + def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( + // Include UUID in module name, so multiple clients resolving maven coordinate at the same time + // do not modify the same resolution file concurrently. + ModuleRevisionId.newInstance("org.apache.spark", + s"spark-submit-parent-${UUID.randomUUID.toString}", + "1.0")) + + /** + * Clear ivy resolution from current launch. The resolution file is usually at + * ~/.ivy2/org.apache.spark-spark-submit-parent-$UUID-default.xml, + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.xml, and + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.properties. + * Since each launch will have its own resolution files created, delete them after + * each resolution to prevent accumulation of these files in the ivy cache dir. + */ + private def clearIvyResolutionFiles( + mdId: ModuleRevisionId, + ivySettings: IvySettings, + ivyConfName: String): Unit = { + val currentResolutionFiles = Seq( + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.properties" + ) + currentResolutionFiles.foreach { filename => + new File(ivySettings.getDefaultCache, filename).delete() + } + } + + /** + * Resolves any dependencies that were supplied through maven coordinates + * @param coordinates Comma-delimited string of maven coordinates + * @param ivySettings An IvySettings containing resolvers to use + * @param exclusions Exclusions to apply when resolving transitive dependencies + * @return The comma-delimited path to the jars of the given maven artifacts including their + * transitive dependencies + */ + def resolveMavenCoordinates( + coordinates: String, + ivySettings: IvySettings, + exclusions: Seq[String] = Nil, + isTest: Boolean = false): String = { + if (coordinates == null || coordinates.trim.isEmpty) { + "" + } else { + val sysOut = System.out + // Default configuration name for ivy + val ivyConfName = "default" + + // A Module descriptor must be specified. 
Entries are dummy strings + val md = getModuleDescriptor + + md.setDefaultConf(ivyConfName) + try { + // To prevent ivy from logging to system out + System.setOut(printStream) + val artifacts = extractMavenCoordinates(coordinates) + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val packagesDirectory: File = new File(ivySettings.getDefaultIvyUserDir, "jars") + // scalastyle:off println + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // scalastyle:on println + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) + } else { + resolveOptions.setDownload(true) + } + + // Add exclusion rules for Spark and Scala Library + addExclusionRules(ivySettings, ivyConfName, md) + // add all supplied maven artifacts as dependencies + addDependenciesToIvy(md, artifacts, ivyConfName) + exclusions.foreach { e => + md.addExcludeRule(createExclusion(e + ":*", ivySettings, ivyConfName)) + } + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) + } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision](-[classifier]).[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + } finally { + System.setOut(sysOut) + clearIvyResolutionFiles(md.getModuleRevisionId, ivySettings, ivyConfName) + } + } + } + + private[deploy] def createExclusion( + coords: String, + ivySettings: IvySettings, + ivyConfName: String): ExcludeRule = { + val c = extractMavenCoordinates(coords)(0) + val id = new ArtifactId(new ModuleId(c.groupId, c.artifactId), "*", "*", "*") + val rule = new DefaultExcludeRule(id, ivySettings.getMatcher("glob"), null) + rule.addConfiguration(ivyConfName) + rule + } + + def parseSparkConfProperty(pair: String): (String, String) = { + pair.split("=", 2).toSeq match { + case Seq(k, v) => (k, v) + case _ => throw new SparkException(s"Spark config without '=': $pair") + } + } + + private[deploy] def getSubmitOperations(master: String): SparkSubmitOperation = { + val loader = Utils.getContextOrSparkClassLoader + val serviceLoaders = + ServiceLoader.load(classOf[SparkSubmitOperation], loader) + .asScala + .filter(_.supports(master)) + + serviceLoaders.size match { + case x if x > 1 => + throw new SparkException(s"Multiple($x) external SparkSubmitOperations " + + s"clients registered for master url ${master}.") + case 1 => serviceLoaders.headOption.get + case _ => + throw new IllegalArgumentException(s"No external SparkSubmitOperations " + + s"clients found for master url: '$master'") + } + } +} + +/** + * Provides an indirection layer for passing arguments as system properties or flags to + * the user's driver program or to downstream launcher tools. 
+ */ +private case class OptionAssigner( + value: String, + clusterManager: Int, + deployMode: Int, + clOption: String = null, + confKey: String = null, + mergeFn: Option[(String, String) => String] = None) + +private[spark] trait SparkSubmitOperation { + + def kill(submissionId: String, conf: SparkConf): Unit + + def printSubmissionStatus(submissionId: String, conf: SparkConf): Unit + + def supports(master: String): Boolean +} -- Gitee