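Spark 1.6 introduced a new RPC architecture built around RpcEnv, RpcEndpoint, and RpcEndpointRef. Early versions shipped two implementations, one on Netty and one on Akka. In the latest code (the 2.11 build), however, Akka is gone; Spark removed its Akka dependency in the 2.0 release. For more background on Netty, see the earlier article.
RpcEndpoint and RpcEndpointRef share one manager: RpcEnv.
RpcEnv is the environment in which RpcEndpoints process messages. It manages the whole RpcEndpoint lifecycle: (1) registering endpoints by name or URI; (2) routing messages to them; (3) stopping endpoints. An RpcEnv must be created through the factory class RpcEnvFactory.
RpcEnv exposes a setupEndpoint method, which registers an RpcEndpoint with the RpcEnv and returns the corresponding RpcEndpointRef.
A client sends a message through an RpcEndpointRef. The RpcEnv handles it first, works out which endpoint it is addressed to, and routes it to that RpcEndpoint. Spark 1.6 offered two implementations, AkkaRpcEnv and NettyRpcEnv, and defaulted to the more efficient NettyRpcEnv.
When message handling throws, RpcEnv uses RpcCallContext.sendFailure to report the failure to the sender; if there is no sender, or the exception is a NotSerializableException, it is written to the log instead. RpcEnv also provides methods for looking up an RpcEndpointRef by name or URI.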
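Before reading the real source, the relationship between the three components can be sketched with a toy model. This is only an illustration under simplifying assumptions (single process, no threads, no serialization); the Mini* names and Recorder are hypothetical and are not Spark classes.

```scala
import scala.collection.mutable

// Toy model only: the Mini* names are hypothetical and are not Spark classes.
object MiniRpcSketch {
  // An endpoint handles messages through a partial function.
  trait MiniEndpoint {
    def receive: PartialFunction[Any, Unit]
  }

  // A ref is a handle that knows only the endpoint's name and its environment.
  final class MiniEndpointRef(val name: String, env: MiniRpcEnv) {
    def send(message: Any): Unit = env.deliver(name, message)
  }

  // The environment registers endpoints by name and routes messages to them.
  final class MiniRpcEnv {
    private val endpoints = mutable.Map.empty[String, MiniEndpoint]

    def setupEndpoint(name: String, endpoint: MiniEndpoint): MiniEndpointRef = {
      require(!endpoints.contains(name), s"There is already an endpoint called $name")
      endpoints(name) = endpoint
      new MiniEndpointRef(name, this)
    }

    def deliver(name: String, message: Any): Unit =
      endpoints.get(name) match {
        case Some(endpoint) =>
          endpoint.receive.applyOrElse[Any, Unit](message, m => sys.error(s"Unsupported message $m"))
        case None =>
          sys.error(s"Could not find $name.")
      }

    def stop(name: String): Unit = endpoints.remove(name)
  }

  // Example endpoint that records every String it receives.
  final class Recorder extends MiniEndpoint {
    val seen = mutable.ArrayBuffer.empty[String]
    override def receive: PartialFunction[Any, Unit] = {
      case s: String => seen += s
    }
  }
}
```

A caller would register a Recorder with setupEndpoint and then talk to it only through the returned ref; the caller never touches the endpoint directly, which is the point of the split.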
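Now let's dig into the Spark source and see how these three components actually work.
First, the RPC environment has to be built. That should happen while SparkContext is being initialized, so let's check whether the code agrees.
SparkContext has a private field, defined as follows: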
private var _env: SparkEnv = _
It is assigned in SparkContext's initialization block:
_env = createSparkEnv(_conf, isLocal, listenerBus)

private[spark] def createSparkEnv(
    conf: SparkConf,
    isLocal: Boolean,
    listenerBus: LiveListenerBus): SparkEnv = {
  SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master))
}

private[spark] def createDriverEnv(
    conf: SparkConf,
    isLocal: Boolean,
    listenerBus: LiveListenerBus,
    numCores: Int,
    mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
  assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!")
  assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
  val hostname = conf.get("spark.driver.host")
  val port = conf.get("spark.driver.port").toInt
  create(
    conf,
    SparkContext.DRIVER_IDENTIFIER,
    hostname,
    port,
    isDriver = true,
    isLocal = isLocal,
    numUsableCores = numCores,
    listenerBus = listenerBus,
    mockOutputCommitCoordinator = mockOutputCommitCoordinator
  )
}
The next function does a lot of work: the entire SparkEnv is built here. Let's look for the part we care about:
private def create(
    conf: SparkConf,
    executorId: String,
    hostname: String,
    port: Int,
    isDriver: Boolean,
    isLocal: Boolean,
    numUsableCores: Int,
    listenerBus: LiveListenerBus = null,
    mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {

  // Listener bus is only used on the driver
  if (isDriver) {
    assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!")
  }

  val securityManager = new SecurityManager(conf)

  val systemName = if (isDriver) driverSystemName else executorSystemName
  val rpcEnv = RpcEnv.create(systemName, hostname, port, conf, securityManager,
    clientMode = !isDriver)

  // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied.
  // In the non-driver case, the RPC env's address may be null since it may not be listening
  // for incoming connections.
  if (isDriver) {
    conf.set("spark.driver.port", rpcEnv.address.port.toString)
  } else if (rpcEnv.address != null) {
    conf.set("spark.executor.port", rpcEnv.address.port.toString)
    logInfo(s"Setting spark.executor.port to: ${rpcEnv.address.port.toString}")
  }

  // Create an instance of the class with the given name, possibly initializing it with our conf
  def instantiateClass[T](className: String): T = {
    val cls = Utils.classForName(className)
    // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
    // SparkConf, then one taking no arguments
    try {
      cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
        .newInstance(conf, new java.lang.Boolean(isDriver))
        .asInstanceOf[T]
    } catch {
      case _: NoSuchMethodException =>
        try {
          cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
        } catch {
          case _: NoSuchMethodException =>
            cls.getConstructor().newInstance().asInstanceOf[T]
        }
    }
  }

  // Create an instance of the class named by the given SparkConf property, or defaultClassName
  // if the property is not set, possibly initializing it with our conf
  def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = {
    instantiateClass[T](conf.get(propertyName, defaultClassName))
  }

  val serializer = instantiateClassFromConf[Serializer](
    "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
  logDebug(s"Using serializer: ${serializer.getClass}")

  val serializerManager = new SerializerManager(serializer, conf)

  val closureSerializer = new JavaSerializer(conf)

  def registerOrLookupEndpoint(
      name: String, endpointCreator: => RpcEndpoint):
    RpcEndpointRef = {
    if (isDriver) {
      logInfo("Registering " + name)
      rpcEnv.setupEndpoint(name, endpointCreator)
    } else {
      RpcUtils.makeDriverRef(name, conf, rpcEnv)
    }
  }

  val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)

  val mapOutputTracker = if (isDriver) {
    new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
  } else {
    new MapOutputTrackerWorker(conf)
  }

  // Have to assign trackerEndpoint after initialization as MapOutputTrackerEndpoint
  // requires the MapOutputTracker itself
  mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME,
    new MapOutputTrackerMasterEndpoint(
      rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))

  // Let the user specify short names for shuffle managers
  val shortShuffleMgrNames = Map(
    "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName,
    "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName)
  val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
  val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
  val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)

  val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false)
  val memoryManager: MemoryManager =
    if (useLegacyMemoryManager) {
      new StaticMemoryManager(conf, numUsableCores)
    } else {
      UnifiedMemoryManager(conf, numUsableCores)
    }

  val blockTransferService =
    new NettyBlockTransferService(conf, securityManager, hostname, numUsableCores)

  val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
    BlockManagerMaster.DRIVER_ENDPOINT_NAME,
    new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)),
    conf, isDriver)

  // NB: blockManager is not valid until initialize() is called later.
  val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
    serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager,
    blockTransferService, securityManager, numUsableCores)

  val metricsSystem = if (isDriver) {
    // Don't start metrics system right now for Driver.
    // We need to wait for the task scheduler to give us an app ID.
    // Then we can start the metrics system.
    MetricsSystem.createMetricsSystem("driver", conf, securityManager)
  } else {
    // We need to set the executor ID before the MetricsSystem is created because sources and
    // sinks specified in the metrics configuration file will want to incorporate this executor's
    // ID into the metrics they report.
    conf.set("spark.executor.id", executorId)
    val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager)
    ms.start()
    ms
  }

  val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
    new OutputCommitCoordinator(conf, isDriver)
  }
  val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator",
    new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
  outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)

  val envInstance = new SparkEnv(
    executorId,
    rpcEnv,
    serializer,
    closureSerializer,
    serializerManager,
    mapOutputTracker,
    shuffleManager,
    broadcastManager,
    blockManager,
    securityManager,
    metricsSystem,
    memoryManager,
    outputCommitCoordinator,
    conf)

  // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is
  // called, and we only need to do it for driver. Because driver may run as a service, and if we
  // don't delete this tmp dir when sc is stopped, then will create too many tmp dirs.
  if (isDriver) {
    val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath
    envInstance.driverTmpDir = Some(sparkFilesDir)
  }

  envInstance
}
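One RPC-related detail in create is worth isolating: registerOrLookupEndpoint takes its endpoint as a by-name parameter (endpointCreator: => RpcEndpoint), so the endpoint expression is evaluated only on the driver branch. The sketch below illustrates just that evaluation behavior; ToyEndpoint, make, and the string results are hypothetical stand-ins, not Spark code.

```scala
// Toy sketch of a by-name parameter; all names here are hypothetical.
object RegisterOrLookupSketch {
  final class ToyEndpoint(val name: String)

  // Counts how many endpoints were actually constructed.
  var constructed = 0

  def make(name: String): ToyEndpoint = {
    constructed += 1
    new ToyEndpoint(name)
  }

  // endpointCreator is by-name: its expression runs only if the body uses it.
  def registerOrLookup(isDriver: Boolean, name: String, endpointCreator: => ToyEndpoint): String =
    if (isDriver) {
      val endpoint = endpointCreator // evaluated here, on the driver branch only
      s"registered:${endpoint.name}"
    } else {
      s"lookup:$name" // the creator expression is never evaluated
    }
}
```

With isDriver set to false, the call returns a lookup result and the construction counter stays at zero, which mirrors how an executor only obtains a ref to the driver's endpoint instead of building its own.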
Now look at the implementation of RpcEnv.create:
def create(
    name: String,
    host: String,
    port: Int,
    conf: SparkConf,
    securityManager: SecurityManager,
    clientMode: Boolean = false): RpcEnv = {
  val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode)
  new NettyRpcEnvFactory().create(config)
}
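So far we have established two things: Netty appears to be the only option for Spark RPC at this point, and the RpcEnv is stored in sc._env (an instance of SparkEnv).
The RpcEnv has now been created. So when do the other two pieces of RPC communication, RpcEndpoint and RpcEndpointRef, come on stage?
When we walked through RDD count earlier, we came across the following code in TaskSchedulerImpl: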
override def submitTasks(taskSet: TaskSet) {
  val tasks = taskSet.tasks
  logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
  this.synchronized {
    val manager = createTaskSetManager(taskSet, maxTaskFailures)
    val stage = taskSet.stageId
    val stageTaskSets =
      taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
    stageTaskSets(taskSet.stageAttemptId) = manager
    val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
      ts.taskSet != taskSet && !ts.isZombie
    }
    if (conflictingTaskSet) {
      throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
        s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
    }
    schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

    if (!isLocal && !hasReceivedTask) {
      starvationTimer.scheduleAtFixedRate(new TimerTask() {
        override def run() {
          if (!hasLaunchedTask) {
            logWarning("Initial job has not accepted any resources; " +
              "check your cluster UI to ensure that workers are registered " +
              "and have sufficient resources")
          } else {
            this.cancel()
          }
        }
      }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
    }
    hasReceivedTask = true
  }
  backend.reviveOffers()
}
reviveOffers is implemented in CoarseGrainedSchedulerBackend:
override def reviveOffers() {
  driverEndpoint.send(ReviveOffers)
}
First we need to see how driverEndpoint is built. It is declared as:
var driverEndpoint: RpcEndpointRef = null
and assigned in start():
override def start() {
  val properties = new ArrayBuffer[(String, String)]
  for ((key, value) <- scheduler.sc.conf.getAll) {
    if (key.startsWith("spark.")) {
      properties += ((key, value))
    }
  }

  // TODO (prashant) send conf instead of properties
  driverEndpoint = createDriverEndpointRef(properties)
}

protected def createDriverEndpointRef(
    properties: ArrayBuffer[(String, String)]): RpcEndpointRef = {
  rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties))
}

protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {
  new DriverEndpoint(rpcEnv, properties)
}
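The call chain that starts at createDriverEndpointRef(properties) in start() breaks down into three steps: 1. create a DriverEndpoint; 2. register that endpoint with the RpcEnv; 3. obtain the endpoint's ref, which is driverEndpoint.
Let's dig further into what setupEndpoint actually does: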
NettyRpcEnv.scala
override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
  dispatcher.registerRpcEndpoint(name, endpoint)
}
Dispatcher.scala
def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
  val addr = RpcEndpointAddress(nettyEnv.address, name)
  val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
  synchronized {
    if (stopped) {
      throw new IllegalStateException("RpcEnv has been stopped")
    }
    if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
      throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
    }
    val data = endpoints.get(name)
    endpointRefs.put(data.endpoint, data.ref)
    receivers.offer(data)  // for the OnStart message
  }
  endpointRef
}
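This breaks down into two steps: 1. create an endpoint ref for the endpoint; 2. store both in concurrent maps (endpoints and endpointRefs are both of type ConcurrentMap): endpoints maps the name to an EndpointData holding the endpoint and its ref, and endpointRefs maps the endpoint to its ref. The EndpointData is also offered to receivers so that the OnStart message gets processed.
With that, the endpoint and the endpoint ref have both made their entrance and are in position.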
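The registration pattern can be sketched in isolation. This is a simplified stand-in, assuming plain case classes in place of Spark's RpcEndpoint, NettyRpcEndpointRef, and EndpointData; Registry, refFor, and pending are hypothetical names added for illustration.

```scala
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

// Toy sketch of name-based registration; these types are not Spark's.
object RegistrySketch {
  final case class Endpoint(label: String)
  final case class EndpointRef(name: String)
  final case class EndpointData(name: String, endpoint: Endpoint, ref: EndpointRef)

  final class Registry {
    private val endpoints = new ConcurrentHashMap[String, EndpointData]
    private val endpointRefs = new ConcurrentHashMap[Endpoint, EndpointRef]
    private val receivers = new LinkedBlockingQueue[EndpointData]

    def register(name: String, endpoint: Endpoint): EndpointRef = {
      val ref = EndpointRef(name)
      // putIfAbsent returns the previous value, or null if the name was free.
      if (endpoints.putIfAbsent(name, EndpointData(name, endpoint, ref)) != null) {
        throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
      }
      val data = endpoints.get(name)
      endpointRefs.put(data.endpoint, data.ref)
      receivers.offer(data) // signals that this endpoint has work waiting
      ref
    }

    // Reverse lookup: from an endpoint to its ref.
    def refFor(endpoint: Endpoint): Option[EndpointRef] = Option(endpointRefs.get(endpoint))

    // Number of entries waiting in the receivers queue.
    def pending: Int = receivers.size
  }
}
```

Registering the same name twice fails because putIfAbsent reports the existing entry, which is the same check that produces the "There is already an RpcEndpoint called ..." error in the Dispatcher.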
Back where we left off, let's see what the send method does:
override def reviveOffers() {
  driverEndpoint.send(ReviveOffers)
}
NettyRpcEndpointRef, defined in NettyRpcEnv.scala:
override def send(message: Any): Unit = {
  require(message != null, "Message is null")
  nettyEnv.send(RequestMessage(nettyEnv.address, this, message))
}
NettyRpcEnv:
private[netty] def send(message: RequestMessage): Unit = {
  val remoteAddr = message.receiver.address
  if (remoteAddr == address) {
    // Message to a local RPC endpoint.
    try {
      dispatcher.postOneWayMessage(message)
    } catch {
      case e: RpcEnvStoppedException => logWarning(e.getMessage)
    }
  } else {
    // Message to a remote RPC endpoint.
    postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
  }
}
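The first thing to understand is what remoteAddr == address means.
After reading through the code: in this scenario we can only end up in dispatcher.postOneWayMessage(message). The ref behind driverEndpoint was created by registerRpcEndpoint with nettyEnv.address as its address, so the receiver's address equals the env's own address and the message is treated as local. When a remote RPC call actually happens is something we will look into later.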
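The branch boils down to an address comparison. A minimal sketch, assuming a simple host and port pair in place of Spark's RpcAddress; Route, LocalDispatch, and RemoteOutbox are hypothetical names for the two outcomes.

```scala
// Toy sketch of the local-versus-remote decision; names are hypothetical.
object SendRoutingSketch {
  final case class Address(host: String, port: Int)

  sealed trait Route
  case object LocalDispatch extends Route
  final case class RemoteOutbox(target: Address) extends Route

  // Equal addresses mean the receiver lives in this very environment.
  def route(self: Address, receiver: Address): Route =
    if (receiver == self) LocalDispatch else RemoteOutbox(receiver)
}
```

Because the comparison is by value, a ref that was built from the environment's own address always takes the local path, even though it is a separate object.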
Dispatcher.scala
def postOneWayMessage(message: RequestMessage): Unit = {
  postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content),
    (e) => throw e)
}

private def postMessage(
    endpointName: String,
    message: InboxMessage,
    callbackIfStopped: (Exception) => Unit): Unit = {
  val error = synchronized {
    val data = endpoints.get(endpointName)
    if (stopped) {
      Some(new RpcEnvStoppedException())
    } else if (data == null) {
      Some(new SparkException(s"Could not find $endpointName."))
    } else {
      data.inbox.post(message)
      receivers.offer(data)
      None
    }
  }
  // We don't need to call `onStop` in the `synchronized` block
  error.foreach(callbackIfStopped)
}
Here the message is posted to data.inbox. Presumably one or more threads watch data.inbox for changes, so we search for data.inbox and find the following code:
private val threadpool: ThreadPoolExecutor = {
  val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
    math.max(2, Runtime.getRuntime.availableProcessors()))
  val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
  for (i <- 0 until numThreads) {
    pool.execute(new MessageLoop)
  }
  pool
}

/** Message loop used for dispatching messages. */
private class MessageLoop extends Runnable {
  override def run(): Unit = {
    try {
      while (true) {
        try {
          val data = receivers.take()
          if (data == PoisonPill) {
            // Put PoisonPill back so that other MessageLoops can see it.
            receivers.offer(PoisonPill)
            return
          }
          data.inbox.process(Dispatcher.this)
        } catch {
          case NonFatal(e) => logError(e.getMessage, e)
        }
      }
    } catch {
      case ie: InterruptedException => // exit
    }
  }
}
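The interplay between postMessage and MessageLoop (enqueue the message, flag the endpoint as ready, let a pool of loops drain it, stop everything with a poison pill) can be sketched with plain JDK queues. This is a toy version under simplifying assumptions: Mailbox stands in for EndpointData plus its Inbox, and drain is a hypothetical helper that runs the whole cycle once.

```scala
import java.util.concurrent.{CountDownLatch, Executors, LinkedBlockingQueue, TimeUnit}
import scala.collection.mutable

// Toy sketch: Mailbox and drain are hypothetical names, not Spark classes.
object MessageLoopSketch {
  // Stands in for EndpointData together with its Inbox.
  final class Mailbox(val name: String) {
    val messages = new LinkedBlockingQueue[String]
  }

  // Sentinel that tells every loop to exit.
  private val PoisonPill = new Mailbox("poison-pill")

  // Starts numThreads loops, posts the messages, and returns what was processed.
  def drain(numThreads: Int, posts: Seq[(Mailbox, String)]): List[String] = {
    val receivers = new LinkedBlockingQueue[Mailbox]
    val processed = mutable.ArrayBuffer.empty[String]
    val done = new CountDownLatch(posts.size)
    val pool = Executors.newFixedThreadPool(numThreads)

    for (_ <- 0 until numThreads) {
      pool.execute(new Runnable {
        override def run(): Unit = {
          var running = true
          while (running) {
            val box = receivers.take() // blocks until some mailbox has work
            if (box eq PoisonPill) {
              receivers.offer(PoisonPill) // put it back so the other loops see it
              running = false
            } else {
              val msg = box.messages.poll()
              if (msg != null) {
                processed.synchronized { processed += s"${box.name}:$msg" }
                done.countDown()
              }
            }
          }
        }
      })
    }

    // Posting: enqueue the message first, then flag the mailbox as ready.
    posts.foreach { case (box, msg) =>
      box.messages.put(msg)
      receivers.offer(box)
    }

    done.await(5, TimeUnit.SECONDS)
    receivers.offer(PoisonPill)
    pool.shutdown()
    pool.awaitTermination(5, TimeUnit.SECONDS)
    processed.synchronized { processed.toList }
  }
}
```

The loops block on the receivers queue, not on any individual mailbox, so idle endpoints cost nothing and a small fixed pool can serve many endpoints. With more than one thread the order in which messages are processed is not deterministic.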
Excellent. Now let's find where data.inbox.process(Dispatcher.this) is implemented, in Inbox.scala:
def process(dispatcher: Dispatcher): Unit = {
  var message: InboxMessage = null
  inbox.synchronized {
    if (!enableConcurrent && numActiveThreads != 0) {
      return
    }
    message = messages.poll()
    if (message != null) {
      numActiveThreads += 1
    } else {
      return
    }
  }
  while (true) {
    safelyCall(endpoint) {
      message match {
        case RpcMessage(_sender, content, context) =>
          try {
            endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
              throw new SparkException(s"Unsupported message $message from ${_sender}")
            })
          } catch {
            case NonFatal(e) =>
              context.sendFailure(e)
              // Throw the exception -- this exception will be caught by the safelyCall function.
              // The endpoint's onError function will be called.
              throw e
          }

        case OneWayMessage(_sender, content) =>
          endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
            throw new SparkException(s"Unsupported message $message from ${_sender}")
          })

        case OnStart =>
          endpoint.onStart()
          if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
            inbox.synchronized {
              if (!stopped) {
                enableConcurrent = true
              }
            }
          }

        case OnStop =>
          val activeThreads = inbox.synchronized { inbox.numActiveThreads }
          assert(activeThreads == 1,
            s"There should be only a single active thread but found $activeThreads threads.")
          dispatcher.removeRpcEndpointRef(endpoint)
          endpoint.onStop()
          assert(isEmpty, "OnStop should be the last message")

        case RemoteProcessConnected(remoteAddress) =>
          endpoint.onConnected(remoteAddress)

        case RemoteProcessDisconnected(remoteAddress) =>
          endpoint.onDisconnected(remoteAddress)

        case RemoteProcessConnectionError(cause, remoteAddress) =>
          endpoint.onNetworkError(cause, remoteAddress)
      }
    }

    inbox.synchronized {
      // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
      // every time.
      if (!enableConcurrent && numActiveThreads != 1) {
        // If we are not the only one worker, exit
        numActiveThreads -= 1
        return
      }
      message = messages.poll()
      if (message == null) {
        numActiveThreads -= 1
        return
      }
    }
  }
}
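The opening synchronized block of process is what keeps a non-concurrent inbox on a single thread at a time: a second thread that arrives while one is active simply returns. A minimal sketch of that guard follows, assuming String messages; ToyInbox, tryAcquire, and release are hypothetical names that split the check out of the processing loop for clarity.

```scala
// Toy sketch of the single-active-thread guard; names are hypothetical.
object InboxGuardSketch {
  final class ToyInbox(enableConcurrent: Boolean) {
    private val messages = new java.util.LinkedList[String]
    private var numActiveThreads = 0

    def post(message: String): Unit = this.synchronized { messages.add(message) }

    // Returns Some(message) if the caller may start processing, None otherwise.
    def tryAcquire(): Option[String] = this.synchronized {
      if (!enableConcurrent && numActiveThreads != 0) {
        None // another thread is already working on this inbox
      } else {
        // poll() returns null on an empty queue, which Option turns into None
        Option(messages.poll()).map { m =>
          numActiveThreads += 1
          m
        }
      }
    }

    // Called by a thread when it stops processing this inbox.
    def release(): Unit = this.synchronized { numActiveThreads -= 1 }
  }
}
```

With enableConcurrent set to false, a second tryAcquire returns None until release is called, even though another message is waiting.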
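By now the picture is mostly clear: after TaskSchedulerImpl.submitTasks, a different thread receives the message. The receiving code is not hard to find; it is in DriverEndpoint, an inner class of CoarseGrainedSchedulerBackend: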
override def receive: PartialFunction[Any, Unit] = {
  case StatusUpdate(executorId, taskId, state, data) =>
    scheduler.statusUpdate(taskId, state, data.value)
    if (TaskState.isFinished(state)) {
      executorDataMap.get(executorId) match {
        case Some(executorInfo) =>
          executorInfo.freeCores += scheduler.CPUS_PER_TASK
          makeOffers(executorId)
        case None =>
          // Ignoring the update since we don't know about the executor.
          logWarning(s"Ignored task status update ($taskId state $state) " +
            s"from unknown executor with ID $executorId")
      }
    }

  case ReviveOffers =>
    makeOffers()

  case KillTask(taskId, executorId, interruptThread) =>
    executorDataMap.get(executorId) match {
      case Some(executorInfo) =>
        executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread))
      case None =>
        // Ignoring the task kill since the executor is not registered.
        logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.")
    }
}
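The link between Inbox.process and this receive method is PartialFunction.applyOrElse: the inbox hands the message content to receive and falls back to an error for anything the endpoint does not match. A small sketch of that dispatch, using a toy endpoint with two message types; ToyDriverEndpoint, deliver, and the counters are hypothetical, and the fallback throws IllegalArgumentException where the Inbox code shown earlier throws SparkException.

```scala
// Toy sketch of receive plus applyOrElse; names are hypothetical.
object ReceiveSketch {
  sealed trait Msg
  case object ReviveOffers extends Msg
  final case class KillTask(taskId: Long, executorId: String) extends Msg

  final class ToyDriverEndpoint {
    var offersMade = 0
    var killed = List.empty[Long]

    // Only the listed cases are defined; everything else is unmatched.
    val receive: PartialFunction[Any, Unit] = {
      case ReviveOffers => offersMade += 1
      case KillTask(taskId, _) => killed = taskId :: killed
    }
  }

  // Same shape as the OneWayMessage branch: try receive, else report the message.
  def deliver(endpoint: ToyDriverEndpoint, content: Any): Unit =
    endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
      throw new IllegalArgumentException(s"Unsupported message $msg")
    })
}
```

Delivering ReviveOffers increments the counter, while delivering an unknown value such as a plain String triggers the fallback instead of a MatchError.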
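At this point the local part of RPC communication is done, and we have a basic picture of the roles RpcEnv, RpcEndpoint, and RpcEndpointRef play in the overall flow.
So far, though, we are still inside the driver and have not really entered the Spark cluster.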