Linux C++ 内存分配 SIGKILL 问题

背景

在代码里看到发现某对象有个 buffer 每次会分配较大内存(定义一个大的 vector),担心如果对象太多会导致 vector 内存分配失败,于是自己写了个简单测试看能不能捕获 bad_alloc

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <limits.h>
#include <iostream>
#include <vector>
using namespace std;

constexpr int size = 1024 * 1024; // 1 MiB

int main() {
int i = 0;
try {
for (; i < INT_MAX - 1; i++) {
new vector<int>(size);
}
cout << "Success" << endl;
} catch (const std::bad_alloc& e) {
cout << "Failed at " << i << "th allocate" << endl;
}
return 0;
}

比较遗憾的是,运行直接显示 Killed 崩溃,在 Linux 上也就是收到了 SIGKILL 信号。这让我想起来当初压测某 C++ 客户端时,对于 500 个分区的 topic,在限制内存为 1 GB 的容器中运行 C++ 客户端,很多都会直接 Killed,日志里也没异常信息。

内存分配

C 标准库的内存分配是使用 malloc(以及 calloc/realloc,这三者类似,这里不讲区别),它的使用很简单,分配失败就返回 NULL,见 man page:

1
2
3
4
The malloc() function allocates size bytes and returns a pointer to
the allocated memory. The memory is not initialized. If size is 0,
then malloc() returns either NULL, or a unique pointer value that can
later be successfully passed to free().

C++ 的 new 则是分为两步,首先是调用 operator new 分配内存,然后就地构造(placement new),对于基本类型就是内存拷贝,对于类而言调用构造函数。

一般而言分配内存是前一步,大概是直接包的 malloc,相比而言,C++ 可以重写 operator new 来执行自定义内存分配策略,虽然 C 替代 malloc 也可以,但不一定通用。C++ new 失败不会返回空指针,而是会抛出 std::bad_alloc 异常。

STL 内存分配则又是一回事,为了能让用户能更精确掌控内存分配(而不是受到 new 的制约),它提供了一个 std::allocator<T> 类模板进行默认的内存分配,而用户可以自行实现接口作为自己的内存分配策略。当然,是不是包的 operator new 我也没去细看。

怀疑是 STL allocator 的问题,于是这里用 new 来替代 STL allocator 看看,由于 vector 一般实现是 3 个指针,所以这里也模拟了实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include <limits.h>
#include <iostream>
using namespace std;

constexpr int size = 1024 * 1024; // 1 MiB

struct VectorInt {
int* start;
int* finish;
int* end;
};

int main() {
int i = 0;
try {
for (; ; i++) {
auto p = new VectorInt;
p->start = new int[size];
}
cout << "Success" << endl;
} catch (const std::bad_alloc& e) {
cout << "Failed at " << i << "th allocate" << endl;
}
return 0;
}

结果还是收到 SIGKILL 了。把 new 改成 malloc 并检查返回值是否为空,结果也是一样。

malloc 真的会返回 NULL 吗?

虽然答案显然是会,毕竟如果不会,那肯定不会是我第一个发现。所以我简单试了下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

int main(int argc, char* argv[]) {
for (int i = 0; i < 1024 * 1024 * 1024; i++) {
int* p = (int*)malloc(sizeof(int));
if (!p) {
printf("[%d] malloc failed: %s\n", i, strerror(errno));
return 1;
}
}
return 0;
}

还是收到 SIGKILL 了。这里我想起了之前看《STL 源码剖析》时,SGI STL 内存分配器对于小内存分配会进行优化,我想 malloc 应该也会,可能是我单次分配的内存(sizeof(int),4 字节)太小了,于是我开始尝试增加单次分配的字节数来观察现象。

  • 4 B:Killed

  • 4 KiB:Killed

  • 4 MiB:Killed

  • 20 MiB:Killed

  • 80 MiB:Killed

  • 1 GiB:没有被 Killed,运行多次的结果如下:

    1
    2
    3
    4
    5
    [131070] malloc failed: Cannot allocate memory
    [131071] malloc failed: Cannot allocate memory
    [131070] malloc failed: Cannot allocate memory
    [131071] malloc failed: Cannot allocate memory
    [131071] malloc failed: Cannot allocate memory

strace 重新运行查看系统调用:

1
2
3
4
5
6
7
8
9
10
brk(0x27a60149a000)                     = 0x27a5c149a000
mmap(NULL, 1073876992, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = -1 ENOMEM (Cannot allocate memory)
mmap(NULL, 134217728, PROT_NONE, MAP_PRIVATE|MAP_ANONYMOUS|MAP_NORESERVE, -1, 0) = 0x3fe58dc4e000
munmap(0x3fe58dc4e000, 37429248) = 0
munmap(0x3fe594000000, 29679616) = 0
mprotect(0x3fe590000000, 135168, PROT_READ|PROT_WRITE) = 0
mmap(NULL, 1073745920, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = -1 ENOMEM (Cannot allocate memory)
fstat(1, {st_mode=S_IFCHR|0620, st_rdev=makedev(136, 3), ...}) = 0
mmap(NULL, 4096, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = 0x7fe5a6a28000
write(1, "[106135] malloc failed: Cannot a"..., 47) = 47

在此之前是大量的 brk,也就是说每次运行 malloc 其实都是调用 brk。而导致 malloc 出错的是 mmap,它设置了错误码 ENOMEM,也就是 malloc 的 man page 里提到的错误码,表示内存不足。

将单词内存分配大小改回 128 MiB 试试(512 MiB 仍然能正常终止),最后几次系统调用:

1
2
3
brk(0x181922339000)                     = 0x181922339000
brk(0x18192a339000 <ptrace(SYSCALL):No such process>
+++ killed by SIGKILL +++

发现并没有调用 mmap,而是持续调用 brk,直到收到 SIGKILL

brk 和 mmap 系统调用

从前一节可知,这里出问题的并不是 STL allocator,而是 malloc 本身。而之前一直在调用 brk。还是参考 man page,这里首先得弄清楚 program break 的概念,首先看下进程虚拟地址空间分布(最简单的模型):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
high address +--------------------+
| |
| |
+--------------------+
| stack |
+---------+----------+
| | |
| | |
| v |
| |
| |
| |
| ^ |
| | |
| | |
+---------+----------+
| heap |
+--------------------+
| uninitialized data |
| (bss) |
+--------------------+
| initialized data |
+--------------------+
| text |
low address +--------------------+

program break 就是 bss 段上面的第一个地址。brk 的函数签名:

1
2
3
4
#include <unistd.h>

int brk(void *addr);
void *sbrk(intptr_t increment);

brk 的作用是改变 program break 的位置为 addr。增加 addr 相当于申请空间。至于 sbrk 则是增加而非设置。

这里以一个最简单的内存申请为例:

1
2
int* p = (int*)malloc(1024 * 1024 * 10);
free(p);

对应系统调用:

1
2
3
brk(0)                                  = 0x2112000
brk(0x2144000) = 0x2144000
mmap(NULL, 10489856, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = 0x7fb6338d2000

第一次调用是 brk(0),返回的不是 0 而是 0x2112000,也就是最初状态下 bss 段上面的第一个地址。后面移动到了新位置,增加了 program break,0x2144000 - 0x2112000 = 0x0032000,也就是 204800 字节,即 197 KiB,显然是不足 10 MiB。

然而接下来的 mmap 第二个参数,10489856 刚好是 10 MiB。brk 只是调整 program break,这只是虚拟地址空间的一个标记,如果要使用这块内存,需要用 mmap 进行映射。

1
2
3
4
5
#include <sys/mman.h>

void *mmap(void *addr, size_t length, int prot, int flags,
int fd, off_t offset);
int munmap(void *addr, size_t length);
  • addr:为 NULL 则由内核选择首地址(满足页对齐),即使不为 NULL,也只是内核根据这个值作为提示信息来选择。
  • length:映射的字节数,实际上会被提升至分页大小(sysconf(_SC_PAGE_SIZE))的整数倍。
  • prot:保护等级,为 PROT_NONE(无法访问)或者 PROT_READ(可读)/PROT_WRITE(可写)/PROT_EXEC(可执行)的组合。如果对这部分内存映射违反了保护等级,会产生 SIGSEGV 信号,也就是常见的段错误(segmentation fault)。
  • flagsMAP_PRIVATE(私有映射)或者 MAP_SHARED(共享映射),前者是对其他进程不可见,后者是可见的,也就是共享内存。

最终内存会映射到从文件描述符 fd 对应的文件的偏移量 offset 处开始,注意 offset 也必须是分页大小的倍数。如果 fd 为 -1,则不映射到文件。

回顾之前的调用:

1
mmap(NULL, 10489856, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = 0x7fb6338d2000

也就是将物理内存的 10 MiB 给映射到堆的虚拟内存空间。

至于 munmap 即删除映射,addr 为对应 mmap 的返回值。

malloc 小内存为何不返回 NULL?

刚才了解了 brkmmap 结合前文内容,我们可知只有 mmap 进行内存映射失败(比如物理内存不够)malloc 才会返回 NULL 并将 errno 置为 ENOMEM。但是对于频繁申请小内存的情况,则是无限调用 brk 导致内存崩溃,并没有出现 mmap。回到第一节的代码(无限创建 1 MiB 的 vector),用 strace 观察的结果是无限调用 mmap 导致内存崩溃:

1
2
3
4
mmap(NULL, 4198400, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = 0x7f3fdead1000
/* ... */
mmap(NULL, 4198400, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = 0x7f3e9ddcf000
+++ killed by SIGKILL +++

可见每次 mmap 长度上限也就 4198400(差不多 4 MiB)。

而假如 vector 大小改成 256 MiB,那么对应 strace 信息为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
brk(0)                                  = 0xcbb000
brk(0xced000) = 0xced000
mmap(NULL, 1073745920, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = 0x7fdda8485000
mmap(NULL, 1073745920, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = 0x7fdd68484000
mmap(NULL, 1073745920, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = 0x7fdd28483000
mmap(NULL, 1073745920, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = 0x7fdce8482000
mmap(NULL, 1073745920, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = 0x7fdca8481000
mmap(NULL, 1073745920, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = -1 ENOMEM (Cannot allocate memory)
brk(0x40ced000) = 0xced000
mmap(NULL, 1073876992, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = -1 ENOMEM (Cannot allocate memory)
mmap(NULL, 134217728, PROT_NONE, MAP_PRIVATE|MAP_ANONYMOUS|MAP_NORESERVE, -1, 0) = 0x7fdca0481000
munmap(0x7fdca0481000, 62386176) = 0
munmap(0x7fdca8000000, 4722688) = 0
mprotect(0x7fdca4000000, 135168, PROT_READ|PROT_WRITE) = 0
mmap(NULL, 1073745920, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = -1 ENOMEM (Cannot allocate memory)

每次 mmap 长度是 1073745920,差不多 1 GiB。这么看,其实根本原因在于 malloc 实现(调用 brkmmap),因为 malloc 允许较为高效的小内存分配(比如它不会因为 N 次 int 的分配就调用 N 次 brk)。

其实在 malloc 的 man page 上面就可以看到原因,以下内容摘自 NOTE 部分。

默认情况下,Linux 遵循乐观的内存分配策略,也就是 malloc 返回非 NULL 时,无法保证内存确实可用,如果系统 OOM 了,进程会被 OOM killer 干掉,也就是 OOM killer 发送 SIGKILL 信号。

通常 malloc() 从堆上分配内存,并使用 sbrk() 调整堆的大小。当分配的内存块字节数大于 MMAP_THRESHOLD 时,glibc 的 malloc() 实现使用 mmap() 进行私有匿名映射来分配内存。MMAP_THRESHOLD 默认 128 kB,但是可以通过 mallopt() 进行调整。

这也说明了,为啥分配大容量 vector 时是一直调用 mmap,而分配 int 时则是一直调用 brk

mmap 的大小也有区别,每次映射 4198400 字节时,最后引向的是 SIGKILL,而每次映射 1073745920 字节时,最后得到的是错误码 ENOMEM。从上述文档可知,前者应该是触发 OOM killer 了,也就是映射了较小的内存,但不一定保证可用就返回了。也就是所谓的 乐观内存分配策略 导致的。

sigkill while allocating memory in c 这个讨论帖有讲,简单说就是:

  • newmalloc 可能从内核拿到了一个不合法的地址,即使没有足够的内存,因为:
    1. 内核直到第一次访问时才分配地址
    2. 如果所有的 overcommited 内存都被使用,操作系统只能杀死其中的一个进程(也就是 OOM killer)。

OOM killer 可参考: https://linux-mm.org/OOM_Killer

简单说就是牺牲一个或多个进程,以便在其他进程故障时为系统释放进程。

那 overcommited 内存呢?它是一个内存分配策略,其值记录在 /proc/sys/vm/overcommit_memory 文件。所谓 overcommit,就是过度提交,其实也就是刚才说的,内核直到第一次访问时才分配地址,否则即使请求的内存超过限制,只要不去访问,内核仍然当作可用。

how to check if malloc overcommits memory 给出了几种解决方法,方法 2 我是开启了 --oom-kill-disable 选项启动 docker,但还是一样。可能是哪里操作不对。方法 3(在 malloc 后进行写入)可能会花很长时间来分配较大内存,但是仍然可能被 kill。

总结

Linux 上,在某些情况下 malloc 即使返回非 NULL 值,由于内核的 overcommit 内存分配策略,也有可能导致这片内存不是可用的,这种内存泄漏累积的结果就是导致进程被 OOM killer 干掉(发送 SIGKILL 信号,无法被捕获)。

单次 malloc 申请较大的内存可以规避这点,因为即使有 overcommit 策略,也不可能超过可用太多。

总之,如果是在限制内存(比如 docker)的环境下运行 C++ 程序,内存分配失败不一定会以 std::bad_alloc 结束,此时会拿不到异常栈信息。虽然用 dmesg 能够找到 OOM 记录,但对调试的帮助仍然有限。这种时候,根本性的解决方案还是自定义内存分配器。

vtable 和 typeinfo 符号丢失排查笔记

[toc]

前言

简单说原来的结构是有一个类 Foo,并且重载了流运算符 operator<<,在另一个类 Caller 中直接 << 打印出 Foo 对象。

现在由于功能扩充,Foo 可能有多种变体,因此需要进行以下重构:

1
2
3
4
FooBase
-> Foo
-> AnotherFoo
-> ...

然后 Caller 内部存放的对象从 std::unique_ptr<Foo> 改成了 std::unique_ptr<FooBase>

基本知识回顾:流运算符重载

对于类 Foo,流运算符重载是一个全局函数(注意,不是类成员函数):

1
2
3
4
std::ostream& operator<<(std::ostream& os, const Foo& foo) {
// os << foo 的内部字段
return os;
}

为了能直接访问 Foo 的内部字段,一般会将该重载函数声明为类 Foo 的友元函数:

1
2
3
4
class Foo {
friend std::ostream& operator<<(std::ostream&, const Foo&);
// ...
};

这里要避免一个误区,那就是将 operator<< 作为类成员函数的话,比如:

1
2
3
4
5
6
7
class Foo {
public:
ostream& operator<<(std::ostream& os) const {
// os << 内部字段
return os;
}
};

对应的调用是这样:

1
2
3
Foo foo;
foo.operator<<(std::cout); // 完整调用方式
foo << std::cout; // 简略版调用方式

std::cout << os 则是由全局函数来重载的,也就是是标准库的 std::ostream 类的 operator<< 方法调用的是 std::ostream& operator<<(std::ostream&, const T&) 函数来实现对任意类型 T 的对象进行输出。

继承体系的解决方式以及 vtable 信息缺失问题

对于继承体系,我们想要的其实是下面这样:

1
2
3
4
5
struct Base { /* ... */ };
struct Derived { /* ... */ };

auto base = new Derived;
std::cout << *base << std::endl; // 调用 Derived 相关的 operator<<,且只暴露 Base 接口

可以用间接的方式来实现:

1
2
3
4
5
6
7
8
9
10
class Base {
friend std::ostream& operator<<(std::ostream& os, const Base& base) {
return base.print(os);
}
public:
virtual std::ostream& print(std::ostream& os) const {
// os << 内部字段
return os;
}
};

派生类只要重载 print 方法即可。

然而我这么干了,编译 OK,但是链接时出错:

undefined reference to `vtable for 【基类名】

undefined reference to `typeinfo for 【基类名】

问题排查

一开始只有这个信息,所以不好排除,我对比了半天(overridevirtual 关键字数量对比),基类的虚函数我在派生类全都实现了啊。


PS:用 override 关键字可以很大程度避免重载函数名写错的情况,比如:

1
2
3
4
5
6
struct Base {
virtual void doSomething() {}
};
struct Derived : Base {
void doSOmeting() override { /* ... */ }
};

此时不小心写错了虚函数名字,用 override 关键字就直接能在编译时提示错误:

error: ‘void Derived::doSOmeting()’ marked ‘override’, but does not override

如果不加 override 关键字,编译就不会出错,但是调用的是基类 BasedoSomething() 方法,派生类并没有重写该方法。导致这种低级错误得等到运行期去排查。


回到问题,这里我就有点束手无策了,总感觉自己有些基本知识弄错了(实际上并没有)。首先想了下是不是我虚析构函数的问题(因为除了虚析构函数外其他虚函数都是纯虚函数):

1
2
3
4
class FooBase {
public:
virtual ~FooBase() {}
};

然后改成了头文件声明,源文件定义:

1
2
3
4
5
// foo_base.h
class FooBase {
public:
virtual ~FooBase();
};
1
2
// foo_base.cc
FooBase::~FooBase() {}

一个有趣的现象,虽然还是报错,但是报错信息变了,有了更具体的信息:

undefined reference to `FooBase::print(std::ostream&) const’

到这里我才回过头来审视 print 方法。不过为了验证观点,首先把这个虚函数给删掉,链接成功,证明了这个观点。再回过头来看,我并没有将其实现为纯虚函数,而是:

1
virtual void print(std::ostream&) const;

只是声明,因此在基类的编译单元缺少其实现信息。看起来很简单的错误,但是在压力和恐慌之下,人的眼睛是不可相信的,至少我的肉眼看到的,这个分号前面就有个 = 0

追根溯源

能定位到这个错误,有一定程度上是因为我恰好把虚函数放到源文件中,有了新的报错信息。那么区别在哪呢?这里复现一下。

复现

给出以下源文件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// base.h
#pragma once

#include <iostream>

struct Base {
virtual ~Base() {}

virtual std::ostream& print(std::ostream& os) const;

friend std::ostream& operator<<(std::ostream& os, const Base& base) {
return base.print(os);
}
};
1
2
3
4
5
6
7
8
// derived.h
#pragma once

#include "base.h"

struct Derived : Base {
std::ostream& print(std::ostream& os) const override;
};
1
2
3
4
5
6
// derived.cc
#include "derived.h"

std::ostream& Derived::print(std::ostream& os) const {
return (os << "Derived");
}

编译成动态库 libbase.so

1
$ g++ -o libbase.so derived.cc -std=c++11 -fPIC -shared

因为是动态库,开启了 -fPIC 选项,即 position-independent code,位置无关的代码,也就是函数(符号)的实现暂时可以不定位到具体的位置,而是在和其他编译单元链接时再定位。

这里给出调用代码:

1
2
3
4
5
6
7
8
9
// main.cc
#include <memory>
#include "derived.h"

int main(int argc, char* argv[]) {
std::unique_ptr<Base> base(new Derived);
std::cout << *base << std::endl;
return 0;
}

编译:

1
2
3
4
5
6
$ g++ main.cc -std=c++11 -L. -lfoo
/tmp/cc39RCrN.o: In function `Base::Base()':
main.cc:(.text._ZN4BaseC2Ev[_ZN4BaseC5Ev]+0x9): undefined reference to `vtable for Base'
/tmp/cc39RCrN.o: In function `Derived::Derived()':
main.cc:(.text._ZN7DerivedC2Ev[_ZN7DerivedC5Ev]+0x19): undefined reference to `vtable for Derived'
collect2: error: ld returned 1 exit status

查看符号表

使用 nm 查看符号表:

1
2
3
4
5
6
7
8
9
10
11
12
13
$ nm libbase.so | egrep "(Base|Derived)"
0000000000000cda W _ZN4BaseD0Ev
0000000000000ca4 W _ZN4BaseD1Ev
0000000000000ca4 W _ZN4BaseD2Ev
0000000000000d42 W _ZN7DerivedD0Ev
0000000000000d00 W _ZN7DerivedD1Ev
0000000000000d00 W _ZN7DerivedD2Ev
0000000000000c20 T _ZNK7Derived5printERSo
U _ZTI4Base
0000000000201058 V _ZTI7Derived
0000000000000dc8 V _ZTS7Derived
U _ZTV4Base
0000000000201030 V _ZTV7Derived

由于 name mangling,上面的比较难辨认。注意第二列的符号,从 nm 帮助手册 可知:

  • W:没有标记成弱对象的弱(Weak)符号,当弱符号链接到普通符号时不会报错,当弱符号被链接且该符号未定义时,该符号的值用一种系统特定的方式决定,不会报错。某些系统上,大写的 W 代表默认值被指定。
  • T:符号在文本(代码)段。
  • U:符号未定义(Undefined )。
  • V弱对象,其余说明同 W

PS:即使使用 delete 禁止拷贝构造函数和拷贝赋值运算符,符号表仍然不变。

如果将 Base::print 改成纯虚函数呢?

1
virtual std::ostream& print(std::ostream& os) const = 0;

符号表变成了:

1
2
3
4
5
6
7
8
9
10
11
12
13
0000000000000dda W _ZN4BaseD0Ev
0000000000000da4 W _ZN4BaseD1Ev
0000000000000da4 W _ZN4BaseD2Ev
0000000000000e42 W _ZN7DerivedD0Ev
0000000000000e00 W _ZN7DerivedD1Ev
0000000000000e00 W _ZN7DerivedD2Ev
0000000000000d20 T _ZNK7Derived5printERSo
00000000002010b8 V _ZTI4Base
00000000002010a0 V _ZTI7Derived
0000000000000ed1 V _ZTS4Base
0000000000000ec8 V _ZTS7Derived
0000000000201078 V _ZTV4Base
0000000000201050 V _ZTV7Derived

这里列出区别:

  • _ZTI4Base_ZTV4BaseU(未定义)变成了 V,也就是弱对象。
  • 多了个弱对象 _ZTS4Base

然后将析构函数单独抽离到 base.cc 来实现,重新编译动态库:

1
$ g++ -o libbase.so base.cc derived.cc -std=c++11 -fPIC -shared

符号表变成了:

1
2
3
4
5
6
7
8
9
10
11
12
13
0000000000000d66 T _ZN4BaseD0Ev
0000000000000d30 T _ZN4BaseD1Ev
0000000000000d30 T _ZN4BaseD2Ev
0000000000000eb2 W _ZN7DerivedD0Ev
0000000000000e70 W _ZN7DerivedD1Ev
0000000000000e70 W _ZN7DerivedD2Ev
0000000000000dec T _ZNK7Derived5printERSo
0000000000201138 V _ZTI4Base
0000000000201170 V _ZTI7Derived
0000000000000f29 V _ZTS4Base
0000000000000f38 V _ZTS7Derived
0000000000201110 V _ZTV4Base
0000000000201148 V _ZTV7Derived

主要区别:

  • 三个 _ZN4BaseD<i>Ev(i 是0,1,2)从 W(弱符号)变成了 T(文本段)。

而重新将 print 改成未定义的函数后,符号表变成了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
0000000000000d96 T _ZN4BaseD0Ev
0000000000000d60 T _ZN4BaseD1Ev
0000000000000d60 T _ZN4BaseD2Ev
0000000000000ee2 W _ZN7DerivedD0Ev
0000000000000ea0 W _ZN7DerivedD1Ev
0000000000000ea0 W _ZN7DerivedD2Ev
U _ZNK4Base5printERSo
0000000000000e1c T _ZNK7Derived5printERSo
0000000000201168 V _ZTI4Base
00000000002011a0 V _ZTI7Derived
0000000000000f59 V _ZTS4Base
0000000000000f68 V _ZTS7Derived
0000000000201140 V _ZTV4Base
0000000000201178 V _ZTV7Derived

最大的区别,这里的 U_ZNK4Base5printERSo,很显然,就是基类 Baseprint 方法。虽然对链接的知识已经忘了不少(得去补课了),但回顾这 4 张符号表,还是可以大致看出为啥析构函数单独分离出去后信息发生变化。

最开始析构函数实现写在 .h 文件里时,未定义的符号有两个(一个是 typeinfo 一个是 vtable):

  • _ZTI4Base
  • _ZTV4Base

而析构函数独立出去后,未定义的符号:

  • _ZNK4Base5printERSoBase 类的 print 函数。

至少现在我们知道了为啥会有提示错误区别,但是编译器为啥这么干,还是不清楚。只能说,从经验的角度

优化一下?

所以说明明可以写在源文件里,为何要写在头文件中?写在源文件里至少还能帮助调试。当然这是出自 C++er 的直觉:要是内联了呢?

于是回到最初的模式(Base 虚析构函数在头文件实现,print 不实现),开 -O2 编译,符号表:

1
2
3
4
5
6
7
8
0000000000000b30 W _ZN7DerivedD0Ev
0000000000000b20 W _ZN7DerivedD1Ev
0000000000000b20 W _ZN7DerivedD2Ev
0000000000000b00 T _ZNK7Derived5printERSo
U _ZTI4Base
0000000000200cb0 V _ZTI7Derived
0000000000000bc0 V _ZTS7Derived
0000000000200cc8 V _ZTV7Derived

相比默认的(-O0 编译):

1
2
3
4
5
6
7
8
9
10
11
12
0000000000000cda W _ZN4BaseD0Ev
0000000000000ca4 W _ZN4BaseD1Ev
0000000000000ca4 W _ZN4BaseD2Ev
0000000000000d42 W _ZN7DerivedD0Ev
0000000000000d00 W _ZN7DerivedD1Ev
0000000000000d00 W _ZN7DerivedD2Ev
0000000000000c20 T _ZNK7Derived5printERSo
U _ZTI4Base
0000000000201058 V _ZTI7Derived
0000000000000dc8 V _ZTS7Derived
U _ZTV4Base
0000000000201030 V _ZTV7Derived

首先前三个 Base 的符号(构造函数)被直接内联了。_ZTV4Base 也没了(虚表?),编译 main.cc 报错信息也少了:

$ g++ main.cc -std=c++11 -L. -lbase -O2
./libbase.so: undefined reference to `typeinfo for Base’
collect2: error: ld returned 1 exit status

也就是说 _ZTV4Base 实际上就是 vtableV 代表 vtable),而另一个保留的 _ZTI4Base 则是类型信息(I 代表 typeinfo)。

看来我作为 C++er 的直觉还是对的,内联了导致虚析构函数就是个普通函数一样。但如果把虚析构函数给独立出去,那么开不开 -O2 优化,结果都一样。

name mangling 还原

其实 nm 已经提供了还原功能了,加上 -C 选项即可(这是 -O2 优化+析构函数在头文件里+print 函数未定义):

1
2
3
4
5
6
7
8
9
$ nm -C libbase.so | egrep "(Base|Derived)"
0000000000000b30 W Derived::~Derived()
0000000000000b20 W Derived::~Derived()
0000000000000b20 W Derived::~Derived()
0000000000000b00 T Derived::print(std::ostream&) const
U typeinfo for Base
0000000000200cb0 V typeinfo for Derived
0000000000000bc0 V typeinfo name for Derived
0000000000200cc8 V vtable for Derived

PS:嗯,前面的内容就当踩坑了……懒得改……

此外,也可以看到析构函数放在源文件里时符号表多了:

1
2
3
0000000000000d70 T Base::~Base()
0000000000000d60 T Base::~Base()
0000000000000d60 T Base::~Base()

对应的符号是以 D 结尾的,即 D0/D1/D2。至于为啥有三个析构函数我也不知道……对比了下普通类,只有两个析构函数(D1/D2)。

为何析构函数的符号有无会导致结果变化

从上述分析可知,析构函数如果放在头文件里,无论是否内联优化,最终符号表里都只是 vtable 缺失。我大致猜测是,仅有一个虚函数的符号没有定义。

于是修改 Base 类的实现,加一个有实现的虚函数 f

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// base.h
#pragma once

#include <iostream>

struct Base {
Base() = default;
Base(const Base&) = delete;
Base& operator=(const Base&) = delete;
virtual ~Base() {}

virtual void f() const;

virtual std::ostream& print(std::ostream& os) const;

friend std::ostream& operator<<(std::ostream& os, const Base& base) {
return base.print(os);
}
};
1
2
3
4
// base.cc
#include "base.h"

void Base::f() const {}

结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
$ g++ -o libbase.so base.cc derived.cc -std=c++11 -fPIC -shared -O2
$ nm -C libbase.so | egrep "(Base|Derived)"
0000000000000d70 W Base::~Base()
0000000000000d60 W Base::~Base()
0000000000000d60 W Base::~Base()
0000000000000de0 W Derived::~Derived()
0000000000000dd0 W Derived::~Derived()
0000000000000dd0 W Derived::~Derived()
0000000000000d50 T Base::f() const
U Base::print(std::ostream&) const
0000000000000db0 T Derived::print(std::ostream&) const
0000000000201038 V typeinfo for Base
0000000000201078 V typeinfo for Derived
0000000000000e68 V typeinfo name for Base
0000000000000e78 V typeinfo name for Derived
0000000000201048 V vtable for Base
0000000000201090 V vtable for Derived

可见,这里 U 不再是符号表,而是虚函数本身。

总结

C++ 的 unsolved symbol 问题其实挺常见的,即使是踩过 N 次坑的我也容易因为一点失误而犯错。本文主要讲述了通过 nm 排查问题的方式,其中如果只有 vtable/typeinfo 缺失这种难以排查的信息,可以尝试加一个带实现的虚函数(比如前文的 f),再来排查符号表。

链接选项 rpath 的应用和原理

前言

在测试和部署 C++ 动态库时,经常遇到的问题就是程序链接到了系统路径下的动态库,有时候 make 编译时链接到本地路径的动态库,但实际 make install 时则会丢失这个依赖。本文将要介绍的就是一种通用解决方法,使用 RPATH 来绑定链接路径。

简单动态库编译和使用示例

给出以下示例库:

1
2
3
4
5
6
7
8
9
10
11
12
// foo.h
#pragma once

#ifdef __cplusplus
extern "C" {
#endif

void foo();

#ifdef __cplusplus
}
#endif
1
2
3
4
5
// foo.cc
#include "foo.h"
#include <stdio.h>

void foo() { printf("foo\n"); }

生成动态库 libfoo.so:

1
g++ -o libfoo.so -fPIC -shared foo.cc

然后给出调用代码:

1
2
3
4
5
6
7
// main.cc
#include "foo.h"

int main(int argc, char* argv[]) {
foo();
return 0;
}

编译时指定链接当前目录:

1
2
3
4
5
$ gcc main.c -L . -lfoo
$ ./a.out
foo
$ ldd a.out | grep libfoo
libfoo.so (0x00007fc2010f7000)

至此,一切正常。

依赖动态库的动态库

实际编写程序时,往往会依赖一些第三方库来避免重复造轮子。比如,这里我们要写一个库依赖于 libfoo.so。

目录层次:

1
2
3
4
5
6
7
8
9
10
.
├── include
│   └── bar.h
├── src
│   └── bar.cc
└── thirdparty
├── include
│   └── foo.h
└── lib
└── libfoo.so

然后编译 libbar.so:

1
g++ -o libbar.so -fPIC -shared src/bar.cc -I include/ -I thirdparty/include/ -L thirdparty/lib/ -lfoo

问题来了,编译出的 libbar.so 找不到 libfoo.so 的依赖:

1
2
$ ldd libbar.so | grep foo
libfoo.so => not found

当然,这样的话,你编译依赖 libbar.so 的程序时会直接失败,从而提醒你去寻找依赖的 libfoo.so:

/usr/bin/ld: warning: libfoo.so, needed by ./libbar.so, not found (try using -rpath or -rpath-link)
./libbar.so: undefined reference to `foo’

注意这里我们第一次见到 rpath 这个概念。

但是问题更大的是,假如 libfoo.so 是一个旧版的库,而有个其他用户完全无视影响,直接将 libfoo.so 安装到了系统目录,比如 /usr/lib64 下面。这样,你的程序依赖的 libbar.so 将会找到系统目录下旧的 libfoo.so,而不是你自己维护的新版。如果新版 libfoo.so 的 ABI 发生了改变而 API 不变,比如这里 C 库变成了 C++ 库:

1
2
3
4
5
// foo.h
#pragma once

void foo(int i = 0);

1
2
3
4
5
6
// foo.cc
#include "foo.h"
#include <stdio.h>

void foo(int i) { printf("foo: %d\n", i); }

API 兼容指的是,调用 foo() 仍然合法,但是由于 C++ 的 name mangling,带有默认参数的 foo 对应的符号发生了变化,因此 foo 可能还会出现这样的错误(main.cc 仅仅是调用 bar() 函数,这里就不贴代码了):

1
2
3
4
$ g++ main.cc -L . -lbar
./libbar.so: undefined reference to `foo(int)'
collect2: error: ld returned 1 exit status

查看 libbar.so 的依赖就什么都明白了:

1
2
3
$ ldd libbar.so | grep libfoo
libfoo.so => /usr/lib64/libfoo.so (0x00007efd6daf1000)

原因是 foo 的函数签名变成了 void foo(int),而链接到的动态库却是全局的 libfoo.so。简单的解决方式是,将本地库的路径加入 LD_LIBRARY_PATH 中:

1
2
3
4
5
6
$ export LD_LIBRARY_PATH=$PWD/thirdparty/lib:$LD_LIBRARY_PATH
$ g++ main.cc -L . -lbar
$ ./a.out
bar
foo: 0

ldd 也能查看 libbar.so 依赖的 libfoo.so 不再是全局的,而是本地的。但问题是,如果发布单独的 libbar.so 给用户,而用户又因为某些原因无法升级全局的 libfoo.so,那么每次用户都要手动设置 LD_LIBRARY_PATH

此时,另一种解决方法刚好能避免这个问题,也就是使用 rpath。

rpath 的使用

rpath 即 runtime path,运行时路径。既可以指定相对路径也可以指定绝对路径。

编译方式:

1
2
3
4
5
6
$ g++ -o libbar.so -fPIC -shared src/bar.cc \
-I include/ -I thirdparty/include/ \
-L thirdparty/lib/ -lfoo -Wl,-rpath=thirdparty/lib/
$ ldd libbar.so | grep foo
libfoo.so => thirdparty/lib/libfoo.so (0x00007f8319965000)

注意最后的 -Wr,-rpath 指定的是动态库的路径。看似和 -L 重复,实际不然。-L 指定的是编译时链接的 libfoo.so 路径,而 -Wl,-rpath 指定的是(libbar.so)运行时链接的 libfoo.so 路径。这里指定的是相对路径。

因此如果我们安装 libbar.so 到全局又不影响全局的 libfoo.so,比如安装到 /usr/lib64 下面:

1
2
$ sudo cp libbar.so /usr/lib64

我们继续编译 libbar.so 的使用程序:

1
2
3
4
5
6
7
8
$ g++ main.cc -lbar
$ ldd a.out | grep foo
libfoo.so => thirdparty/lib/libfoo.so (0x00007f5d3d79f000)
$ ldd a.out | grep bar
libbar.so => /usr/lib64/libbar.so (0x00007fed07e78000)
$ ls /usr/lib64/libfoo.so
/usr/lib64/libfoo.so

这样想发布依赖高版本 libfoo.so 的 libbar.so 时,用户只需要在编译和运行时,相对路径 thirdparty/lib 下面有高版本 libfoo.so 就行了,无需覆盖全局路径下的低版本 libfoo.so。

注意如果换个路径运行 a.out,由于 rpath 指定的是相对路径,此时会找不到 libfoo.so。

所以 rpath 指定绝对路径的做法也是比较常见的,比如编译 libbar.so 时将 libfoo.so 置于不会冲突的系统目录:

1
2
3
4
5
6
7
8
$ sudo mkdir -p /usr/lib64/foo-1.1
$ sudo cp thirdparty/lib/libfoo.so /usr/lib64/foo-1.1
$ g++ -o libbar.so -fPIC -shared src/bar.cc \
-I include/ -I thirdparty/include/ \
-L /usr/lib64/foo-1.1 -lfoo -Wl,-rpath=/usr/lib64/foo-1.1
$ ldd libbar.so | grep foo
libfoo.so => /usr/lib64/foo-1.1/libfoo.so (0x00007f83df270000)

那么用户部署时,只需要将 libfoo.so 放置在 /usr/lib64/foo-1.1 下面就行,这里的 1.1 用于标识版本号。由于该目录并不会被自动连接,从而防止了其他程序自动链接到这个版本的 libfoo.so。

回到现代,使用 CMake

但凡稍有规模的程序,直接使用 GCC 编译来构建项目是难以维护的。即使有了 Makefile,管理和维护起来还是相对麻烦。C++ 缺乏类似 Maven 那样的构建系统,但退而求其次,CMake 已经成为了事实上的 C++ 构建通用解决方案。(虽然早期流行的 Autotools 仍然有一定市场)

以一个极简的 CMakeLists.txt 为例,将 rpath 指定为相对路径

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
cmake_minimum_required(VERSION 2.8.12)
project(Bar CXX)

# 设置 find_path/find_library 查找的根目录,默认会从 include 以及 lib 子目录查找
set(CMAKE_PREFIX_PATH "${PROJECT_SOURCE_DIR}/thirdparty")
find_path(FOO_INCLUDE_DIR NAMES foo.h)
find_library(FOO_LIB NAMES libfoo.so)

add_library(bar SHARED src/bar.cc)
include_directories(./include ${FOO_INCLUDE_DIR})
target_link_libraries(bar ${FOO_LIB})

# 设置 rpath,这里是绝对路径
set(CMAKE_INSTALL_RPATH "${PROJECT_SOURCE_DIR}/thirdparty/lib")

# 安装到 lib 子目录,该相对路径是相对 CMAKE_INSTALL_PREFIX 而言的
install(TARGETS bar LIBRARY DESTINATION lib)

PS:说是回到现代,我这个 CMake 还是老式的风格,现代 CMake 又是另一个话题了,不熟悉 CMake 的话,可以直接从现代 CMake 学起。

当前目录层次:

1
2
3
4
5
6
7
8
9
10
11
12
13
.
├── CMakeLists.txt
├── include
│   └── bar.h
├── main.cc
├── src
│   └── bar.cc
└── thirdparty
├── include
│   └── foo.h
└── lib
└── libfoo.so

使用 CMake 构建项目:

1
2
3
4
5
$ mkdir _builds && $ cd _builds/
$ cmake .. -DCMAKE_INSTALL_PREFIX=$PWD/..
$ make
$ make install

之后目录层次(忽略中间目录 _builds):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
.
├── CMakeLists.txt
├── include
│   └── bar.h
├── lib
│   └── libbar.so
├── main.cc
├── src
│   └── bar.cc
└── thirdparty
├── include
│   └── foo.h
└── lib
└── libfoo.so

类似地,为了部署的话,可以将 libfoo.so 部署到系统目录 /usr/lib64 的子目录。

也可以修改成相对路径:

1
2
set(CMAKE_SHARED_LINKER_FLAGS ${CMAKE_SHARED_LINKER_FLAGS} -Wl,-rpath,'$ORIGIN/thirdparty')

其中 $ORIGIN 会被替换成动态库所处的绝对路径,也就是说只要 libfoo.so 处于 libbar.so 同级目录 thirdparty 下面,如下图所示:

1
2
3
4
5
.
├── libbar.so
└── thirdparty
└── libfoo.so

之后 libbar.so 就会链接到 thirdparty/libfoo.so,并且都是绝对路径。

Linux 动态库查找路径

最后一节,以理论来结尾。前文侧重实践,有了实践作为基础,回过头来看原理就更有体会了。

一个典型的 C/C++ 程序的构建流程是:预处理,汇编,编译,链接。而执行链接的程序其实是 ld,通常编译器比如 GCC 都会自动调用 ld 去进行链接,用户不必关注其中的细节。而 ld 查找动态库的顺序是:

  1. rpath 指定的目录;
  2. 环境变量 LD_LIBRARY_PATH 指定的目录;
  3. runpath 指定的目录;
  4. /etc/ld.so.cache 缓存文件,通常包含 /etc/ld.so.conf 文件编译出的二进制俩别哦(比如 CentOS 上,该文件会使用 include 从而使用 ld.so.conf.d 目录下面所有的 *.conf 文件,这些都会缓存在 ld.so.cache 中)
  5. 系统默认路径,比如 /lib/usr/lib

在编译时若使用 -z nodefaultlib 选项编译,则会跳过 4 和 5。至于 runpath,和 rpath 类似,都是二进制(ELF)文件的动态 section 属性(分别为 DT_RUNPATHDT_RPATH),唯一区别就是是否优先于 LD_LIBRARY_PATH 来查找。这里就不详述了。

总结

至此,读者对如何编译/部署动态库,以及动态库之间的依赖关系应该有了一定的认识。

相比而言,静态链接,静态链接部署简单,像 Golang 这种语言直接全部静态链接,受到了不少用户的青睐,而且占用体积大在现代已经几乎不再是需要特别考虑的问题。

但动态库有动态库的好处,比如在大型项目有多个组件依赖时,如果全部静态链接,则每次修改依赖的模块,都要将主模块重新编译一遍,对于 C++ 这种编译速度可能会非常耗时的语言是灾难性的。

另外,提供插件式接口给解释型语言(比如 Python 和 PHP)来调用时,动态库是必须的,解释器可以动态加载动态库。如果使用静态链接,恐怕没人愿意每换一个插件就要将解释器重新编译一遍。

Kafka源码阅读12: 高性能计时器SystemTimer

前言

前一篇阅读了时间轮 TimingWheel 的实现,遗留了两个重要问题:

  1. 时间轮中被插入延迟队列的桶,何时被移除?
  2. 高层时间轮运转时,定时任务何时被插入低层时间轮?

实际上,在 kafka.utils.timer 包的类中,真正暴露给其它包的只有 SystemTimer,而且注解为 @threadsafe(线程安全),时间轮 TimingWheel 只不过是它的一个字段,本身注解也是 @nonthreadsafe(非线程安全)。SystemTimer 实现了接口 Timer,是基于 Kafka 时间轮设计的高性能定时器。

构造

字段含义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class SystemTimer(executorName: String,
tickMs: Long = 1,
wheelSize: Int = 20,
startMs: Long = Time.SYSTEM.hiResClockMs) extends Timer {

// 线程池(任务执行器),固定大小为 1,也就是同时只能执行最多一个任务
private[this] val taskExecutor = Executors.newFixedThreadPool(1, new ThreadFactory() {
// 自定义线程创建方式:非守护线程,指定 executorName 作为线程名后缀
def newThread(runnable: Runnable): Thread =
KafkaThread.nonDaemon("executor-"+executorName, runnable)
})

private[this] val delayQueue = new DelayQueue[TimerTaskList]()
private[this] val taskCounter = new AtomicInteger(0)
private[this] val timingWheel = new TimingWheel(
tickMs = tickMs,
wheelSize = wheelSize,
startMs = startMs,
taskCounter = taskCounter,
delayQueue
)

// 读写锁,保护时间轮运转(tick)时的相关数据结构
private[this] val readWriteLock = new ReentrantReadWriteLock()
private[this] val readLock = readWriteLock.readLock()
private[this] val writeLock = readWriteLock.writeLock()

主构造器 4 个参数第一个用于指定线程名称,后面三个用于构造时间轮。此外,所有时间轮(包括各个桶)共享一个延迟队列和任务计数器。多层时间轮共享的延迟队列就是这里的 delayQueue,调用 poll 时会将过期的桶弹出队列。

细节 1:高精度时间戳计时

注意到 startMs 是通过 System.nanoTime() 转换得到的高精度纳秒级时间戳:

1
2
3
public interface Time {

Time SYSTEM = new SystemTime();
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public class SystemTime implements Time {

@Override
public long milliseconds() {
return System.currentTimeMillis();
}

@Override
public long hiResClockMs() {
return TimeUnit.NANOSECONDS.toMillis(nanoseconds());
}

@Override
public long nanoseconds() {
return System.nanoTime();
}

之所以使用 nanoTime 是为了高精度计时,但是由于纳秒级时间戳超过了 64位 整型能表达的上限,所以得到的是溢出值(还有可能为负数),只能用于计算两个时间戳的时间间隔,而不能用作时间戳。因此在记录时间戳(比如 Produce 请求得到 LogAppendTime 时)以及对时间间隔精确性不敏感的地方都是用的 currentMilliseconds 方法计时。

细节 2:KafkaThread

Java 线程池屏蔽了线程的细节,用户只要提供了实现 Runnable 的类,即可通过 executesubmit 方法创建线程。出于灵活性考虑,Java 线程池也支持用户自定义 ThreadFactory 接口,实现 newThread 通过 Runnable 对象创建线程的方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public static KafkaThread nonDaemon(final String name, Runnable runnable) {
return new KafkaThread(name, runnable, false);
}

public KafkaThread(final String name, Runnable runnable, boolean daemon) {
super(runnable, name);
configureThread(name, daemon);
}

private void configureThread(final String name, boolean daemon) {
setDaemon(daemon);
setUncaughtExceptionHandler(new UncaughtExceptionHandler() {
public void uncaughtException(Thread t, Throwable e) {
log.error("Uncaught exception in thread '{}':", name, e);
}
});
}

这里是通过 KafkaThread 类(位于 org.apache.kafka.common.utils 包下)的工厂方法创建的,关键的是设置了异常处理器,当线程函数中抛出意想不到的异常时,将其写入错误日志。

但是,仅当 Runnable 对象由 execute 执行时才会调用这个处理器,因为 submit 执行 Runnable 会返回 Future<?> 对象,只有调用 Future 对象的 get 方法时才会触发异常,这样用户就可以手动 try-catch 捕获异常,而不用自定义异常处理器。

而在 SystemTimer 中,任务是使用 submit 执行的,并且未处理返回的 Future 对象:

1
taskExecutor.submit(timerTaskEntry.timerTask)

因此,虽然 KafkaThread 设置了异常处理器,但是在这里,定时任务抛出的异常实际上被忽略了。

Timer 接口实现

接口概览

SystemTimer 是基于 TimingWheel 实现的定时器,对外提供的功能即它所实现的接口 Timer:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
trait Timer {
/**
* 添加新的任务到当前执行器(线程池),在任务过期后会执行任务。
* @param timerTask 待添加的任务
*/
def add(timerTask: TimerTask): Unit

/**
* 推进内部时钟,执行任何在走过的时间间隔内过期的任务
* @param timeoutMs
* @return 是否有任务被执行
*/
def advanceClock(timeoutMs: Long): Boolean

// 取得待执行的任务数量
def size: Int

// 关闭定时器服务,待执行的任务将不会被执行
def shutdown(): Unit
}

其中 sizeshutdown 的实现很简单,分别是取得 taskCounter 的值以及关闭线程池。

1
2
3
4
5
def size: Int = taskCounter.get

override def shutdown() {
taskExecutor.shutdown()
}

add

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def add(timerTask: TimerTask): Unit = {
readLock.lock()
try {
// 通过任务的延时加上当前时间得到延时的具体时刻,作为定时任务的过期时间
addTimerTaskEntry(new TimerTaskEntry(timerTask, timerTask.delayMs + Time.SYSTEM.hiResClockMs))
} finally {
readLock.unlock()
}
}

private def addTimerTaskEntry(timerTaskEntry: TimerTaskEntry): Unit = {
// 尝试将任务加入时间轮
if (!timingWheel.add(timerTaskEntry)) {
// 仅当 任务已经过期 或者 任务主动取消 才会进入此分支
if (!timerTaskEntry.cancelled) // 任务过期则执行任务
taskExecutor.submit(timerTaskEntry.timerTask)
}
}

其实就是将任务扔进时间轮中,添加失败只有可能是过期或者主动取消,这里额外判断了是否任务主动取消。

唯一值得注意的是这里用了读锁,按照常理,add 并不是读操作而是写操作,为什么是读锁呢?读锁意味着可以多线程同时调用 add 时无需上锁。这是因为 TimingWheel.add 是线程安全的,回顾下时间轮添加任务的流程:

  1. 判断任务是否被取消

    任务绑定的 Entryprivate[this] 修饰的,也就是仅有当前对象能访问。因此只要不是两个相同任务,那么这个判断是线程安全的

  2. 判断过期时间处于那个桶,是否需要加入更高一级时间轮

    所有桶的时间范围由 currentTime(即取整后的 startMs)、tickMswheelSize 决定的,而 add 方法并不会修改它们

  3. 将任务添加进桶:TimeTaskList.add 内部用内置锁保护了,线程安全;

  4. 设置桶的过期时间:调用原子变量的 getAndSet 方法,也是线程安全的。

保证线程安全的策略是要么不修改内部状态,要么调用那些线程安全的方法,因此允许并发地 add

advanceClock

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def advanceClock(timeoutMs: Long): Boolean = {
// 尝试在 timeoutMs 内取出完成的任务
var bucket = delayQueue.poll(timeoutMs, TimeUnit.MILLISECONDS)
if (bucket != null) { // 取出了过期的 bucket
writeLock.lock()
try {
while (bucket != null) {
// 推进当前时间轮,内部可能会递归推进更高一层时间轮,currentTime 被修改
timingWheel.advanceClock(bucket.getExpiration())
// 取出 bucket 所有任务节点,将其传入 reinsert 方法
bucket.flush(reinsert)
// 非阻塞地取出任务,将当前时点所有过期的 bucket 全部取出
bucket = delayQueue.poll()
}
} finally {
writeLock.unlock()
}
true
} else { // 没有 bucket 过期
false
}
}

推进 timeoutMs 毫秒,尽可能取出此时所有过期的 bucket(问题 1 解决),然后调用 flush 将 bucket 中所有任务节点传入 reinsert

1
2
3
4
5
6
7
8
9
10
11
def flush(f: (TimerTaskEntry)=>Unit): Unit = {
synchronized {
var head = root.next
while (head ne root) {
remove(head)
f(head)
head = root.next
}
expiration.set(-1L)
}
}

TimerTaskList.flush 方法很简单,用内置锁保护,然后依次删除链表(bucket)所有节点,并应用到函数上,最后重置 expiration 以保证下次有任务加入该 bucket 时,该 bucket 会被加入延迟队列。

1
private[this] val reinsert = (timerTaskEntry: TimerTaskEntry) => addTimerTaskEntry(timerTaskEntry)

reinsert 则是尝试将这些从 bucket 中删除的节点重新加入时间轮。

这里需要注意, bucket 过期时,内部节点也都过期了,因为 bucket 的过期时间是所有内部过期时间取整后得到的被 tickMs 整除的值。那为什么要这么做呢?

回顾我们开始提出的问题 2,如果取出的这个 bucket 是属于高层时间轮的,由于高层时间轮精度不够,此时 bucket 可能并未过期。

举个两层时间轮的例子(单位:毫秒):

层次 buckets
1 [0,1) [1,2)
2 [0,2) [2,4)

初始状态下,延时为 3 的任务被加入 [2,4),调用 advanceClock(2) 后,时间轮变成了

层次 buckets
1 [2,3) [3,4)
2 [2,4) [4,6)

第 2 层的 [2,4) 被取出,然后延时为 3 的任务被取出,此时调用 reinsert 就会将其加入第 1 层的 [3,4),而不是立刻判断它过期。至此,问题 2 解决,从高层时间轮降级到底层时间轮被隐藏在了这句不起眼的 bucket.flush(reinsert) 中。

总结

本章阅读了 Kafka 高精度定时器 SystemTimer 的实现,它管理了延迟队列和时间轮,每次加入定时任务将任务扔进时间轮中,并将任务节点所在的 bucket 扔进延迟队列中。它本身的推进是通过延迟队列进行的,每次推进一段时间,尽可能取出到期的 bucket,并依次取出 bucket 的所有任务节点。通过将取出的任务节点重新加入到时间轮中,可能会将高层时间轮中过期任务转移到底层时间轮中。

此外,对于到期的任务,SystemTimer 使用仅包含单线程的线程池执行,若推进时又多个任务节点被取出,会等待前一个任务对应的线程完成后才会继续执行该任务(复用这个线程)。

Kafka源码阅读11: 时间轮TimingWheel

前言

前一章阅读了各种延迟操作类的基类 DelayedOperation,而延迟操作对象会传入 DelayedOperationPurgatory,查看其构造参数:

1
2
3
4
5
6
purgatoryName: String,
timeoutTimer: Timer,
brokerId: Int = 0,
purgeInterval: Int = 1000,
reaperEnabled: Boolean = true,
timerEnabled: Boolean = true

ReplicaManager 中是调用 apply 方法构造的,这里的 timer 使用 util.timer.SystemTimer

1
val timer = new SystemTimer(purgatoryName)

SystemTimer内部一个重要字段就是时间轮 TimingWheel 对象:

1
2
3
4
5
6
7
private[this] val timingWheel = new TimingWheel(
tickMs = tickMs,
wheelSize = wheelSize,
startMs = startMs,
taskCounter = taskCounter,
delayQueue
)

设计思路

实现在 utils/timer/TimingWheel.scala 中,这是 Kafka 精心设计的时间轮,因此关于该类的说明有长达 70 多行,这里首先阅读其设计思路。

简单时间轮

简单时间轮通常是时间任务桶的循环链表。令 u 为时间单元,一个大小为 n 的时间轮有 n 个桶,能够持有 n * u 个时间间隔的定时任务。

每个桶持有进入相应时间范围的定时任务。最开始,第一个桶持有 [0, u) 范围的任务,第二个桶持有 [u, 2u) 范围的任务……第 n 个桶持有 [u * (n - 1), u * n) 范围的任务。每过一个时间单元 u,定时器会 tick 并移动到下个桶,然后其中所有的定时任务都会过期。由于任务已经过期,此时定时器不会插入任务到当前桶中。定时器会立刻运行过期的任务。因为空桶在下一轮是可用的,所以如果当前的桶对应时间 t,那么它会在 tick 后变成 [t + u * n, t + (n + 1) * u) 的桶。

时间轮的插入/删除(即启动/停止定时器)的时间复杂度是 O(1),而基于优先队列的定时器,比如 java.util.concurrent.DelayQueuejava.util.Timer 插入/删除的时间复杂度是 O(log n)


本质上时间轮就是个哈希表,因此插入/删除的时间复杂度是 O(1),而哈希表的 value 类型是链表,插入/删除的时间复杂度也是 O(1),因此将定时任务 TimerTaskEntry 插入到时间轮/从时间轮中删除的时间复杂度也是 O(1)

分层时间轮

简单时间轮的主要缺点是它假设定时器请求是在从当前时刻开始的 n * u 时间间隔内,如果定时器请求超出了这个间隔就会产生溢出。分层时间轮会处理这种溢出,它以层次来组织时间轮,最底层的精度更高,层数越高,表示的精度更低。如果某一层时间轮的精度是 u,大小是 n,则更高一层的精度是 n * u。每一层的溢出会被委托给更高层的时间轮。当更高层的时间轮 tick 时,它会把定时任务插入到更底层。溢出的时间轮会按照需求来创建。当溢出的时间轮的桶过期时,其中所有任务会重新递归地插入到定时器中,之后这些任务会被移动到精度更高的时间轮中或者被执行。设 m 是时间轮的数量,则插入(启动定时器)的时间复杂度是 O(m),相比起系统中请求的数量,通常是小很多的。而删除(停止定时器)的时间复杂度仍然是 O(1)


像时钟就是一个典型的三层时间轮,秒针能表示 0 到 59 秒,但是对 60 秒以上则需要分针进一步表示,再进一步即时针,一共能表示的时间范围为 0 到 43199 秒,精度为 1 秒。从秒针到分针到时针,表示精度是依次降低的,秒针精度为 1 秒,有 60 格,因此分针精度是 1 * 60 = 60 秒,类似地,时钟精度是 3600 秒。而上文用到的 tick 一词,则对应秒针/分针/时针的走动。

时间轮的每个时间间隔都对应了一个桶(bucket),即定时任务链表 TimerTaskList。根据每个定时任务的 timeout(过期时间),决定将任务分配给那个桶。

示例

u = 1, n = 3,设起始时刻是 c,则各层次的桶为

层次 精度
1 [c,c] [c+1,c+1] [c+2,c+2] 1
2 [c,c+2] [c+3,c+5] [c+6,c+8] 3
3 [c,c+8] [c+9,c+17] [c+18,c+26] 9

PS:这里沿用了代码注释里的表示,即闭区间,而前面讲述原理时都是左闭右开区间,两者是等价的,只是表示不一致。

c+1 时刻,桶 [c,c][c,c+2][c,c+8]过期了,之后:

  • 1 层的时钟移动到 c+1,并且创建新的桶 [c+3,c+3]
  • 2、3 层的时钟仍然在 c 处,因为他们的精度是 3 和 9。

此时各层次的桶为:

层次 精度
1 [c+1,c+1] [c+2,c+2] [c+3,c+3] 1
2 [c,c+2] [c+3,c+5] [c+6,c+8] 3
3 [c,c+8] [c+9,c+17] [c+18,c+26] 9

注意,桶 [c,c+2] 不会接收任何任务,因为此时时刻是 c+1,只有 timeout 为 c+1c+2 才会被分配到该桶,然而 1 层的两个桶 [c+1,c+1] [c+2,c+2] 会优先接收任务。类似地,3 层的 [c+1,c+8] 也不会接收任何任务,因为这个范围被 2 层的桶覆盖了。

依次类推,在 c+3 时刻,2 层也会创建新的桶,各层次的桶为:

层次 精度
1 [c+3,c+3] [c+4,c+4] [c+5,c+5] 1
2 [c+3,c+5] [c+6,c+8] [c+9,c+11] 3
3 [c,c+8] [c+9,c+17] [c+18,c+26] 9

PS:这里源码的注释说 3 层的第 3 个桶是 [c+8,c+11],看了下,大概是注释错误?

实现

TimeWheel 的字段

主构造器的字段

名称 类型 说明
tickMs Long 每 tick 一次经过的毫秒数,即前文的时间单元 u
wheelSize Int 时间轮大小,即前文的桶数 n
startMs Long 毫秒级时间戳
taskCounter AtomicInteger 任务数量,即所有桶(链表)中的节点数量之和
queue DelayQueue[TimerTaskList]

注意到这里还有个 DelayQueue 作为辅助,具体作用之后再看。

通过上述主构造参数可以计算出以下私有字段(private[this],可以被包内其他类访问)

1
2
3
4
5
6
7
8
9
10
// 当前时间轮的整个时间跨度,即更高一层时间轮的 tickMs
private[this] val interval = tickMs * wheelSize
// 创建 wheelSize 个桶(定时任务链表)
private[this] val buckets = Array.tabulate[TimerTaskList](wheelSize) { _ => new TimerTaskList(taskCounter) }

// 向下取整,使起始时间戳能被 tickMs 整除
private[this] var currentTime = startMs - (startMs % tickMs) // rounding down to multiple of tickMs

// 高一层时间轮,用来保存超过 interval 的任务
@volatile private[this] var overflowWheel: TimingWheel = null

注意这里做了取整,因此左闭右开区间 [currentTime, currentTime + tickMs) 即时间轮第一个桶的范围。

通过 addOverflowWheel 创建高一层时间轮:

1
2
3
4
5
6
7
8
9
10
11
12
13
private[this] def addOverflowWheel(): Unit = {
synchronized {
if (overflowWheel == null) { // Double-Checked Locking 模式
overflowWheel = new TimingWheel(
tickMs = interval, // 仅有此参数和之前不同,见分层时间轮一节的解释
wheelSize = wheelSize,
startMs = currentTime,
taskCounter = taskCounter,
queue
)
}
}
}

添加定时任务

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def add(timerTaskEntry: TimerTaskEntry): Boolean = {
// 定时任务的过期时间戳
val expiration = timerTaskEntry.expirationMs

if (timerTaskEntry.cancelled) {
// Entry 绑定的 TimerTask 调用了 cancel() 方法主动将 Entry 从链表中移除
false
} else if (expiration < currentTime + tickMs) {
// 过期时间在第一个桶的范围内,表示已经过期,此时无需加入时间轮
false
} else if (expiration < currentTime + interval) {
// 过期时间在当前时间轮能表示的时间范围内,加入到其中一个桶
// 注意按照这个算法,第一个桶的时间范围是 [c+u,c+u*2),因为 [c,c+u) 范围内被视为已过期
// 而且第一个桶对应 buckets 的下标并不一定是 0,因为数组只是作为循环队列的存储方式,起始下标无所谓
val virtualId = expiration / tickMs
val bucket = buckets((virtualId % wheelSize.toLong).toInt)
bucket.add(timerTaskEntry)

// 设置过期时间,这里也取整了,即可以被 tickMs 整除
if (bucket.setExpiration(virtualId * tickMs)) { // 仅在新的过期时间和之前的不同才返回 true
// 由于进行了取整,同一个 bucket 所有节点的过期时间都相同,因此仅在 bucket 的第一个节点加入时才会进入此 if 块
// 因此保证了每个桶只会被加入一次到 queue 中,queue 存放所有包含定时任务节点的 bucket
// 借助 DelayQueue 来检测 bucket 是否过期,bucket 时遍历即可取出所有节点
queue.offer(bucket)
}
true
} else {
// 过期时间在当前时间轮表示的范围之外,即溢出,需要创建高一层时间轮来加入
if (overflowWheel == null) addOverflowWheel() // 双重检查上锁的第一层检查
overflowWheel.add(timerTaskEntry) // 注意高一层时间轮也可能无法容纳,因此可能会递归创建更高层级的时间轮
}
}

主要知识点在前面的设计思路中都讲到了,可以看到 DelayQueue 对象 queue 在时间轮的作用是,保存包含定时任务节点的桶,桶可以来自不同层次的时间轮,当然,所有层次时间轮也共享这个队列。

TimeWheel 本身没有实现 tick 功能,而是借助延迟队列 DelayQueue 来实现时间的推移,假设有 M 个定时任务分布在 N 个桶中,那么插入的时间复杂度为 O(M + N * log N),其中 M >= N。如果把任务全存到延迟队列中,那么插入的时间复杂度为 O(M * log M),因此 Kafka 时间轮的优化是有意义的。

比如对于 1 层时间轮的 3 个桶:[0,4)[4,8)[8,12),有以下过期时间的定时任务:

1
1,2,3,8,9,10,11

那么会向 queue 中插入 2 个桶,然后利用 queue 依次弹出 2 个桶,通过遍历弹出每个桶的节点:

  • 时刻 0:弹出节点 1,2,3;
  • 时刻 8:弹出节点 8,9,10,11。

删除定时任务

再再再次回顾,延迟操作 DelayedOperation 对象,继承自定时任务 TimerTask 接口,而 TimerTask 会绑定一个 TimerTaskEntry 节点,每个节点位于唯一对应的链表 TimerTaskList (即 bucket)上。

定时任务的删除即调用 TimerTaskList.remove 方法(TimerTaskEntry.remove 也会调用该方法),有以下几种可能:

  • 延时操作对象主动调用 cancel 和节点解绑,解绑后的节点也无法加入到 bucket 中;
  • 当前 bucket 上的节点被另一个 bucket 调用 add 方法,此时会先从当前 bucket 上移除该节点。

时间轮的转动

1
2
3
4
5
6
7
8
def advanceClock(timeMs: Long): Unit = {
if (timeMs >= currentTime + tickMs) { // timeMs 超过了当前 bucket 的时间范围
currentTime = timeMs - (timeMs % tickMs) // 修改当前时间,即原先的第一个桶已经失效

// 若存在更高层的时间轮,则也会向前运转
if (overflowWheel != null) overflowWheel.advanceClock(currentTime)
}
}

总结

本文叙述了 Kafka 分层时间轮的设计思路,并阅读了其源码实现,在 Kafka 这种需要处理大量异步任务(延时请求、定时任务,都可以视为等价的概念)的系统上,基于优先级队列的 DelayQueue 性能不够高,因此 Kafka 借助了时间轮的思想,将同一个时间范围内的异步任务放到一个桶中,进一步将桶放入优先级队列。核心思想是同一个时间区间范围的多个任务,只需要加入一次到优先级队列中。

底层数据结构是:

  • 定长数组实现循环队列,来模拟时间轮;
  • 时间轮的每个 bucket(即数组元素)为链表,链表上每个节点对应一个定时任务;
  • 多层时间轮通过单个时间轮的链表来实现。

顺便,本文留下了一个问题,那就是 queue 调用了 offer 方法将 bucket 加入到队列中,但是在 TimeWheel.scala 源码中,没有看到 queue 调用 poll 方法弹出 bucket。

此外,设计思路部分前文提到的了:

当更高层的时间轮 tick 时,它会把定时任务插入到更底层。

如何降级,在 TimingWheel 中没有体现。

其实这些是在 SystemTimer 中实现的,它进一步包装了 TimingWheel,也是 kafka.utils.timer 包中唯一暴露给外部的类,下一篇文章将会阅读其实现。

Kafka源码阅读10: 延迟操作DelayedOperation

前言

之前阅读了 Produce 和 Fetch 请求的实现,对于需要耗时处理的网络请求,都是利用 DelayedOperationDelayedOperationPurgatory 来进行异步延迟操作,防止阻塞 KafkaRequestHandler 线程。

比如处理 Produce 请求时,ReplicaManager.appendRecords 方法在 ack 为 -1,有数据发送且有至少有一个分区的 append 操作成功时:

1
2
3
val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock)
val producerRequestKeys = entriesPerPartition.keys.map(new TopicPartitionOperationKey(_)).toSeq
delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)

再比如,处理 Fetch 请求时,ReplicaManager.fetchMessages 方法在 timeout 大于 0,读取本地数据没出错且响应积攒的字节数足够多时:

1
2
3
val delayedFetch = new DelayedFetch(timeout, fetchMetadata, this, quota, isolationLevel, responseCallback)
val delayedFetchKeys = fetchPartitionStatus.map { case (tp, _) => new TopicPartitionOperationKey(tp) }
delayedFetchPurgatory.tryCompleteElseWatch(delayedFetch, delayedFetchKeys)

用法很类似,首先创建 DelayedXXX 对象,然后对每个分区都创建 TopicPartitionOperationKey 对象组成 Seq,将两者传入 purgatory 中进行 tryXXX 操作,从命名和注释可以猜到,这个操作是尝试完全请求,就像 Scala 的 Promise 类的 tryComplete 方法一样,异步操作的常见模式就是下面两个非阻塞操作:

  1. 启动一个任务异步执行,然后当前线程该干嘛干嘛;
  2. 想要确认任务是否执行结束时,看一眼,如果结束了就取得结果。

不过除了确认操作(Operation)是否完成外,还会在没有完成的时候,监控(Watch)相应的延迟操作。

1
def tryCompleteElseWatch(operation: T, watchKeys: Seq[Any]): Boolean
1
2
3
4
5
6
7
/* used by delayed-produce and delayed-fetch operations */
case class TopicPartitionOperationKey(topic: String, partition: Int) extends DelayedOperationKey {

def this(topicPartition: TopicPartition) = this(topicPartition.topic, topicPartition.partition)

override def keyLabel = "%s-%d".format(topic, partition)
}

实际上 Key 可以是任意类型,只要实现了 keyLabel 方法,主题和分区组成的 Key 是用于延迟的 Produce 和 Fetch 操作,而对于其它请求/操作则用的其它类型的 Key,比如 JoinGroup 操作用 group id 和 consumer id 组成 Key。

通过前面的源码阅读可知,Produce 和 Fetch 请求全程都是按照分区去处理,也就是每个分区对应一个类型,然后对这个类型进行 map, filter 等等,所以这里传入 purgatory 的 key 可以唯一标识延迟处理的数据,比如 Fetch 操作中需要处理的 FetchMetadata

1
2
3
case class FetchMetadata(/* 其它字段... */
// key: 分区; value: 获取分区的状态
fetchPartitionStatus: Seq[(TopicPartition, FetchPartitionStatus)])

DelayOperation

主要方法

首先看看 DelayedOperationPurgatory 类及其 tryCompleteElseWatch 方法的完整签名

1
2
3
4
final class DelayedOperationPurgatory[T <: DelayedOperation](/* ... */)
extends Logging with KafkaMetricsGroup {

def tryCompleteElseWatch(operation: T, watchKeys: Seq[Any]): Boolean

参数 1 是泛型参数 T,该类型必须继承自 DelayedOperation,该类型是抽象类

1
2
abstract class DelayedOperation(override val delayMs: Long,
lockOpt: Option[Lock] = None) extends TimerTask with Logging {

设置了毫秒级延时 delayMs 以及可选的锁 lockOpt。要实现一个延迟操作,也就是继承 DelayedOperation 类并重写(override)以下抽象方法:

1
2
3
4
5
6
7
8
9
10
// 回调: 当延迟操作过期时执行, 因此 delayMs 到期时会强制完成
def onExpiration(): Unit

// 回调: 当操作完成时执行
def onComplete(): Unit

// 检查现在操作是否已经完成:
// 1. 已完成, 则调用 forceComplete() 并返回 true 如果 forceComplete 返回 true;
// 2. 否则返回 false
def tryComplete(): Boolean

此外提供了原子 Boolean 类型的字段来标识是否完成

1
2
3
4
5
6
7
8
9
10
11
12
13
private val completed = new AtomicBoolean(false)

def isCompleted: Boolean = completed.get()

def forceComplete(): Boolean = {
if (completed.compareAndSet(false, true)) {
cancel() // 调用基类 TimerTask 的方法取消当前定时任务
onComplete() // 执行派生类自定义的回调
true
} else {
false
}
}

该原子变量是在 forceComplete 中设置为 true 的,可能有多个线程尝试完成同一个任务,由于是执行原子 Boolean 的 CAS 操作,只有第一个线程会返回 true,onComplete() 回调只会被调用一次。

前面的注释也提过,简单查看 DelayedProduceDelayedFetchtryComplete() 实现也能看到,每个代表任务已完成的分支,都会将 forceComplete() 的返回值作为 tryComplete() 的返回值。

此外,基类 TimerTask 继承自 Runnable 接口,因此 DelayedOperation 可以作为线程被启动,执行 run() 方法:

1
2
3
4
override def run(): Unit = {
if (forceComplete())
onExpiration()
}

如果任务未完成,则强制完成(期间会执行 onComplete 回调),并执行 onExpiration 回调。相当于多了个对超时的处理,因此可以猜测会在过期时启动线程来执行超时回调。

maybeTryComplete

该方法是 server 包私有的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
private[server] def maybeTryComplete(): Boolean = {
var retry = false
var done = false
do {
if (lock.tryLock()) {
// 上锁成功, 直接调用 tryComplete
try {
tryCompletePending.set(false)
done = tryComplete()
} finally {
lock.unlock()
}
// 此时可能另外一个线程调用了 `maybeTryComplete` 并将 `tryCompletePending` 置为 true(在此之前已经设为了 false)
// 因此 retry 为 true 就代表这种情况发生, 此时会触发重试条件, 继续 while 循环
retry = tryCompletePending.get()
} else { // 上锁失败, 说明另一个线程持有锁
// 如果此时 `tryCompletePending` 为 true, 那么持锁线程必然看到了 true 并且重试, 那么当前线程可以退出。
// 否则持锁线程正在 `tryComplete`, 此时将其设为 true 因为持锁线程可能在 `tryCompletePending` 设为 true 的时候返回
retry = !tryCompletePending.getAndSet(true)
}
} while (!isCompleted && retry)
done
}

看起来有点绕,分情况讨论。

  1. 单线程:调用 tryComplete 后退出循环,因为自己将 tryCompletePending 置为了 false,解锁后retry 为false,此时和直接 tryComplete 无异;
  2. 双线程:记为 A 和 B,假设 A 上锁成功,且在 tryComplete 检查完成状态的时候 B 上锁失败,那么 B 将 tryCompletePending 置为 true,这会导致两种情况:
    • A tryComplete 成功,代表 onComplete 已被调用,isCompleted 为 true,A 和 B 都会退出循环;
    • A tryComplete 失败,isCompleted 为 false,由于 tryCompletePending 被 B 置为 false,A 的 retry 为 true,而由于不存在其它等待线程,所以 B 在 getAndSet 时得到的值(赋给 retry)也是 true,A 和 B 重新争夺锁,也就是说至少会再调用一次 tryComplete
  3. 三个以上线程:存在 1 个持锁线程和 N 个等待线程(N > 1),getAndSet 是原子操作,也就是说 N 个等待线程只有 1 个等待线程的 retry 会被置为 true,其它线程都因为第一个调用 getAndSet 的等待线程将 tryCompletePending 置为 false 时退出,此时和双线程的情况无异。

核心就是双线程的情况,这种做法是为了针对这种场景:线程 A 检查完成状态的时候,此时还未完成,而线程 B 检查完成状态的时候,虽然实际已经完成了,但由于线程 A 正持有锁,B 不会检查状态。这种做法能让无论线程 A 还是 B,都会再调用一次 tryComplete 检查是否完成了。

举个例子,假设状态由 1 个 Boolean 表示,只有都为 true 时才算完成。如果不这么设计,那就可能出现下面的情况:

时刻 状态 线程 A 线程 B
1 false 上锁
2 false 取得状态 false
3 true 判断状态(已经是旧的状态)是否为 true 上锁失败,不检查状态
4 true 返回 false 返回 false

而现在的做法就是在第 4 步,让线程 A 和 B 再次竞争 tryComplete 的机会,至少有一个线程能检查新的状态。

TimerTask

DelayedOperation.forceComplete 有一个关键的 cancel() 调用来自于基类 TimerTask,从源码注释可知,这个方法是取消 timeout 计时器,即强行停止耗时超过 delayedMs 的延时任务。这个方法是在基类 TimerTask 中实现的。

1
2
3
4
5
6
7
8
private[this] var timerTaskEntry: TimerTaskEntry = null

def cancel(): Unit = {
synchronized {
if (timerTaskEntry != null) timerTaskEntry.remove()
timerTaskEntry = null
}
}

实际上是调用了 TimerTaskEntry.remove 方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
private[timer] class TimerTaskEntry(val timerTask: TimerTask, val expirationMs: Long) extends Ordered[TimerTaskEntry] {

@volatile
var list: TimerTaskList = null

def remove(): Unit = {
var currentList = list
// 如果另一个线程将当前节点从一个链表移动到另一个链表,由于 list 会被修改成新链表的引用
// 所以 remove 会失败,因此在这里用 currentList 暂存之前链表的引用,这样就避免锁住整个 list
// 单线程: list.remove(this) 会将 this.list 置为 null, 移除后退出循环;
// 多线程: 如果 this.list 被其它线程修改指向了新的链表, 那么循环会继续, 将该节点从新链表移除
// 一个罕见场景: 线程 B 将该节点从链表 A 移除加入链表 B, 但是在修改节点的 list 之前, 线程 A
// 就移除成功了, 获取的 list 为 null, 退出循环, 之后线程 B 将节点加入链表 B
while (currentList != null) {
currentList.remove(this)
currentList = list
}
}
}

TimerTaskEntryTimerTaskList (定时任务链表)上的一个节点(Entry),内部维护了链表的引用,调用 remove 即将当前节点从链表上移除。

1
2
3
4
5
6
7
8
9
10
11
12
13
def remove(timerTaskEntry: TimerTaskEntry): Unit = {
synchronized { // 保护多线程 remove 和 add
timerTaskEntry.synchronized { // 保护单个节点的 add 和 remove
if (timerTaskEntry.list eq this) { // 确认当前节点还在当前链表上才移除
timerTaskEntry.next.prev = timerTaskEntry.prev
timerTaskEntry.prev.next = timerTaskEntry.next
timerTaskEntry.next = null
timerTaskEntry.prev = null
timerTaskEntry.list = null
taskCounter.decrementAndGet()
}
}
}

可见 TimerTaskList 是双向链表,移除节点时会锁住该节点,然后修改其 prevnext,并维护了原子变量的 taskCounter 记录任务节点的数量。除了 remove 外只有 add 方法会在多线程下造成 race condition,因此要加锁。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def add(timerTaskEntry: TimerTaskEntry): Unit = {
var done = false
while (!done) {
// 如果节点存在于另一个链表中, 将其删除从而保证一个节点只被一个链表拥有
// 这里不加锁是因为 remove 也会调用 synchronized, 会造成死锁, remove 本身
// 会反复重试直到节点的链表为 null
timerTaskEntry.remove()

synchronized { // 保护多线程 add
timerTaskEntry.synchronized { // 保护单个节点的 add 和 remove
if (timerTaskEntry.list == null) { // 确认当前节点已经成功被 remove
val tail = root.prev
timerTaskEntry.next = root
timerTaskEntry.prev = tail
timerTaskEntry.list = this
tail.next = timerTaskEntry
root.prev = timerTaskEntry
taskCounter.incrementAndGet()
done = true
}
}
}
}
}

PS:这里似乎锁住单个节点没必要,因为 addremove 本身都已经上锁了,大概是防止将来有其它方法直接修改节点的字段吧,毕竟节点的字段(这里主要指list)都不是私有的。(虽然我看到的对 list 的修改也就只在 add/remove 方法中)。

总之,这里单独实现的 TimerTaskList,除了实现了线程安全的增删操作外,主要是保证了链表上的节点(对应一个定时任务)是对应唯一的链表的,主要是为了保证节点在从链表 A 迁移到链表 B 时,不会继续留存在 A 中。

TimerTask 会绑定一个 TimerTaskEntry 类型的节点,该节点位于 TimerTaskList 类型的双向链表上,链表包含一个字段:expirationMs,即任务的毫秒级 timeout。任务链表在源码注释中也被称为 bucket(桶),其本身也有一个原子类型的 expiration 字段代表任务链表本身的 timeout,提供了 setter 和 getter,需要注意的是 setter 返回的是 Boolean 而非 Unit,为 true 时则代表 expiration 发生了变化。

TimerTask 在绑定新的 TimerTaskEntry 时,如果和之前的节点不一样,也会将其移除:

1
2
3
4
5
6
7
8
9
10
private[timer] def setTimerTaskEntry(entry: TimerTaskEntry): Unit = {
synchronized {
// if this timerTask is already held by an existing timer task entry,
// we will remove such an entry first.
if (timerTaskEntry != null && timerTaskEntry != entry)
timerTaskEntry.remove()

timerTaskEntry = entry
}
}

而在创建定时任务节点时,会自动和构造参数的 TimerTask 对象绑定:

1
2
3
4
5
private[timer] class TimerTaskEntry(val timerTask: TimerTask, val expirationMs: Long) extends Ordered[TimerTaskEntry] {

// if this timerTask is already held by an existing timer task entry,
// setTimerTaskEntry will remove it.
if (timerTask != null) timerTask.setTimerTaskEntry(this)

总结

延时操作(DelayedOperation)是一个抽象基类,继承自定时任务(TimerTask),由派生类实现以下方法:

  • 任务完成的回调;
  • 任务超时的回调;
  • 非阻塞地确认任务是否完成。

每个延时操作都是一个定时任务(TimerTask),对应一个定时任务节点(TimerTaskEntry),而每个任务节点都是存在一个 bucket(定时任务链表,TimerTaskList)上的。通过线程安全的 removeadd 操作可以让节点从一个 bucket 移动到另一个 bucket,整个过程中节点都始终对应唯一的 bucket,不可能被多个 bucket 共享。这使得任务能安全地在多个 bucket 之间迁移。也就是接下来要阅读的时间轮 TimeWheel

Kafka源码阅读09: Fetch请求

Fetch协议

Fetch API用于为某些分区获取日志,逻辑上它指定主题分区起始offset来取得消息,消息格式参考The Messages Fetch

KafkaApis.handleFetchRequest

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
val versionId = request.header.apiVersion
val clientId = request.header.clientId
val fetchRequest = request.body[FetchRequest]
val fetchContext = fetchManager.newContext(fetchRequest.metadata(),
fetchRequest.fetchData(),
fetchRequest.toForget(),
fetchRequest.isFromFollower()) // replicaId >= 0, 即非负id代表fetch请求来自follower

val erroneous = mutable.ArrayBuffer[(TopicPartition, FetchResponse.PartitionData)]() // 分区 -> 响应
val interesting = mutable.ArrayBuffer[(TopicPartition, FetchRequest.PartitionData)]() // 分区 -> 请求
if (fetchRequest.isFromFollower()) { // fetch 请求来自于 follower
if (authorize(request.session, ClusterAction, Resource.ClusterResource)) {
// 认证成功,判断请求的每个分区是否存在,若存在则将分区对应的请求加入 interesting 中
// 否则则构造错误响应加入 erroneous
fetchContext.foreachPartition((part, data) => {
if (!metadataCache.contains(part.topic)) {
erroneous += part -> new FetchResponse.PartitionData(Errors.UNKNOWN_TOPIC_OR_PARTITION,
FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET,
FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY)
} else {
interesting += (part -> data)
}
})
} else { // 认证失败,对所有分区都构造错误响应加入 erroneous
fetchContext.foreachPartition((part, data) => {
erroneous += part -> new FetchResponse.PartitionData(Errors.TOPIC_AUTHORIZATION_FAILED,
FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET,
FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY)
})
}
} else { // fetch 请求来自于客户端(消费者),和之前处理一样,认证失败或者分区不存在则构造错误响应
fetchContext.foreachPartition((part, data) => {
if (!authorize(request.session, Read, new Resource(Topic, part.topic)))
erroneous += part -> new FetchResponse.PartitionData(Errors.TOPIC_AUTHORIZATION_FAILED,
FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET,
FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY)
else if (!metadataCache.contains(part.topic))
erroneous += part -> new FetchResponse.PartitionData(Errors.UNKNOWN_TOPIC_OR_PARTITION,
FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET,
FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY)
else
interesting += (part -> data)
})
}

可见主要是调用authorize方法进行 ACL 认证,以及查询metadataCache判断请求的分区是否存在。对于 follower,认证是基于整个请求的,操作是ClusterAction;对于 consumer,认证是基于每个分区的,类型是Read

只有经过认证且存在于metadataCache的分区对应的请求会加入interesting中,其它分区会构造一个默认的不合法响应加入erroneous中。

接下来定义了如下回调函数:

1
2
3
def convertedPartitionData(tp: TopicPartition, data: FetchResponse.PartitionData): FetchResponse.PartitionData

def processResponseCallback(responsePartitionData: Seq[(TopicPartition, FetchPartitionData)])

然后调用ReplicaManager.fetchMessages方法对 interesting 请求进行处理:

1
2
3
4
5
6
7
8
9
10
11
12
13
if (interesting.isEmpty)
processResponseCallback(Seq.empty)
else {
replicaManager.fetchMessages(
fetchRequest.maxWait.toLong, // 最大等待时间,毫秒
fetchRequest.replicaId, // 副本 id,客户端为 Consumer 则为 -1
fetchRequest.minBytes, // 响应中积攒的最小字节数
fetchRequest.maxBytes, // 响应中积攒的最大字节数
versionId <= 2, // maxBytes 字段是从 V3 才引入的,因此判断 API 版本以兼容旧版本请求
interesting, // 通过认证且分区存在的请求
replicationQuota(fetchRequest),
processResponseCallback, // 处理响应的回调
fetchRequest.isolationLevel)

ReplicaManager.fetchMessages

主要实现

方法说明:从 leader 副本取得消息,等待足够数据可以获取。一旦超时或者请求条件被满足则回调被调用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def fetchMessages(timeout: Long,
replicaId: Int,
fetchMinBytes: Int,
fetchMaxBytes: Int,
hardMaxBytesLimit: Boolean,
fetchInfos: Seq[(TopicPartition, PartitionData)],
quota: ReplicaQuota = UnboundedQuota,
responseCallback: Seq[(TopicPartition, FetchPartitionData)] => Unit,
isolationLevel: IsolationLevel) {
val isFromFollower = Request.isValidBrokerId(replicaId) // replicaId >= 0 (Follower) 则为 true
// replica id 不为 -2 (debugging) 和 -3 (future local) 则为 true, 即正常 Fetch 请求都只从 leader 获取
val fetchOnlyFromLeader = replicaId != Request.DebuggingConsumerId && replicaId != Request.FutureLocalReplicaId
// replica id 为 -1 (Consumer) 且不为 -3 (future local) 则为 true, 即 Consumer 仅获取已提交的 offsets
val fetchOnlyCommitted = !isFromFollower && replicaId != Request.FutureLocalReplicaId

def readFromLog(): Seq[(TopicPartition, LogReadResult)] = { /* ... */ }

// 从本地消息日志读取结果
val logReadResults = readFromLog() // Seq[(TopicPartition, LogReadResult)]

// 所有分区的 LogReadResult 组成的 Seq
val logReadResultValues = logReadResults.map { case (_, v) => v }
// 总共读取的字节数
val bytesReadable = logReadResultValues.map(_.info.records.sizeInBytes).sum
// 如果存在 LogReadResult 的 error 字段不为 NONE 则为 true, 即存在读取错误
val errorReadingData = logReadResultValues.foldLeft(false) ((errorIncurred, readResult) =>
errorIncurred || (readResult.error != Errors.NONE))

if (timeout <= 0 || fetchInfos.isEmpty || bytesReadable >= fetchMinBytes || errorReadingData) {
// 请求不想等待 or 请求消息为空 or 读取的总字节数超过了最小积攒字节数 or 存在读取错误
// 此时直接生成结果给回调函数处理
val fetchPartitionData = logReadResults.map { case (tp, result) =>
tp -> FetchPartitionData(result.error, result.highWatermark, result.leaderLogStartOffset, result.info.records,
result.lastStableOffset, result.info.abortedTransactions)
}
responseCallback(fetchPartitionData)
} else {
// Map 类型, key 为 TopicPartition, value 为 FetchPartitionStatus
val fetchPartitionStatus = logReadResults.map { case (topicPartition, result) =>
// 对每个 LogReadResult, 从 fetchInfos 中找到第一个分区相同的 PartitionData, 若找不到分区, 则抛出 RuntimeException
// PartitionData 包含以下字段:
// fetchOffset: Long 要获取的消息 offset
// logStartOffset: Long follower 第一个可用 offset, V5 新增字段
// maxBytes: Long 响应中累积的最大字节数, V3 新增字段
val fetchInfo = fetchInfos.collectFirst {
case (tp, v) if tp == topicPartition => v
}.getOrElse(sys.error(s"Partition $topicPartition not found in fetchInfos"))
// fetchOffsetMetadata: LogOffsetMetadata 来自从本地日志读取的信息
// fetchInfo: PartitionData 来自客户端的请求字段, 利用 FetchContext 得到的
(topicPartition, FetchPartitionStatus(result.info.fetchOffsetMetadata, fetchInfo))
}
// 转发输入参数构造 DelayedFetch 对象
val fetchMetadata = FetchMetadata(fetchMinBytes, fetchMaxBytes, hardMaxBytesLimit, fetchOnlyFromLeader,
fetchOnlyCommitted, isFromFollower, replicaId, fetchPartitionStatus)
val delayedFetch = new DelayedFetch(timeout, fetchMetadata, this, quota, isolationLevel, responseCallback)

// 构造 (topic, partition) 键值对作为延迟 fetch 操作的 key
val delayedFetchKeys = fetchPartitionStatus.map { case (tp, _) => new TopicPartitionOperationKey(tp) }

// 尝试立刻完成请求, 否则将其放入 purgatory, 因为每次创建延迟 fetch 操作时, 新的请求可能到达并使其可完成
delayedFetchPurgatory.tryCompleteElseWatch(delayedFetch, delayedFetchKeys)
}
}
  1. 从本地日志文件中读取得到请求的每个分区的结果(LogReadResult);
  2. 若出现以下错误,则立刻将读取结果构造成 FetchPartitionData 交给回调函数处理;
    • timeout(对应请求的 max_wait_time字段)小于0,即客户端不想等待;
    • 读取结果为空,即客户端请求的任何分区都无法从本地读到结果;
    • 读取字节数不小于 fetchMinBytes(对应请求的 min_bytes 字段);
    • 在读取某个请求的分区的结果时存在错误。
  3. 否则,遍历每个分区的读取结果,和请求中同一分区的请求字段一起构造 FetchPartitionStatus
  4. 构造 DelayedFetch 对象,尝试完成请求,否则将其放入 delayedFetchPurgatory 中延迟处理。

关键的部分就是 readFromLog() 函数和延迟处理的部分。延迟处理相关设施(purgatory,DelayedOperation)在之后去阅读,本篇最后阅读 readFromLog() 和发送响应的回调函数的实现。

responseCallback

KafkaApis.handleFetchRequest 方法中定义的回调函数 processResponseCallback,用来在处理请求完成,构造响应后将响应发送给客户端。

这部分不细看了,因为有不少逻辑是为了实现事务以及配置限额的,这不是目前我阅读源码的重点。核心处理分为两步:

  1. 通过 convertedPartitionDataPartitionData 转换成和兼容旧版本的响应结构;
  2. 调用 KafkaApis.sendResponse 发送响应,在之前的 Produce 请求(2): 发送响应 中都看过这个方法,简单回顾下,实际上就是把响应加入 Processor 的响应队列,之后的发送由 Processor 处理,参考 网络层阅读之 Acceptor 和 Processor 的 4.2 节。

readFromLog

1
2
3
4
5
6
7
8
9
10
11
12
13
def readFromLog(): Seq[(TopicPartition, LogReadResult)] = {
val result = readFromLocalLog(
replicaId = replicaId, // 副本 id, 客户端为 Consumer 则为 -1
fetchOnlyFromLeader = fetchOnlyFromLeader,
readOnlyCommitted = fetchOnlyCommitted,
fetchMaxBytes = fetchMaxBytes, // max_bytes 字段
hardMaxBytesLimit = hardMaxBytesLimit, // 请求版本 >= V3 则为 true, 此时请求有 max_bytes 字段
readPartitionInfo = fetchInfos, // 通过认证且分区存在的分区信息
quota = quota,
isolationLevel = isolationLevel)
if (isFromFollower) updateFollowerLogReadResults(replicaId, result)
else result
}

调用 readFromLocalLog,如果 Fetch 请求来自 follower 则还需要调用 updateFollowerLogReadResults 更新 follower 的结果。

readFromLocalLog

首先看看内部定义的 read 函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def read(tp: TopicPartition, fetchInfo: PartitionData, limitBytes: Int, minOneMessage: Boolean): LogReadResult = {
val offset = fetchInfo.fetchOffset
val partitionFetchSize = fetchInfo.maxBytes
val followerLogStartOffset = fetchInfo.logStartOffset

try {
// 决定是否仅从 leader 获取, 然而无论是 Consumer 还是 Follower 都会从 leader 获取
val localReplica = if (fetchOnlyFromLeader)
getLeaderReplicaIfLocal(tp) // 分区的 leader 副本
else
getReplicaOrException(tp)

val initialHighWatermark = localReplica.highWatermark.messageOffset
val lastStableOffset = if (isolationLevel == IsolationLevel.READ_COMMITTED)
Some(localReplica.lastStableOffset.messageOffset)
else
None

// decide whether to only fetch committed data (i.e. messages below high watermark)
val maxOffsetOpt = if (readOnlyCommitted)
Some(lastStableOffset.getOrElse(initialHighWatermark))
else
None

/* 在读取日志之前首先读取 LogOffsetMetadata, 它能判断指定副本是否同步
* 在读取之后再使用 LEO 会导致 race condition, 比如在副本完成消费后, 数据立刻添加到了日志末尾,
* 这可能导致副本一直被判断不同步
*/
val initialLogEndOffset = localReplica.logEndOffset.messageOffset // 在读取操作之前取得 LEO
val initialLogStartOffset = localReplica.logStartOffset
val fetchTimeMs = time.milliseconds // 当前时间戳
val logReadInfo = localReplica.log match {
case Some(log) =>
// 取得 partition_max_bytes (分区本身的最大读取字节数) 和 max_bytes 的较小值作为 fetch 字节数上限
val adjustedFetchSize = math.min(partitionFetchSize, limitBytes)

// 读取 offset 开始的最多 adjustedFetchSize 个字节, 若 minOneMessage 为 true, 则即使第一条消息大小
// 超过了 adjustedFetchSize 也会返回这条消息, 返回类型: FetchDataInfo
val fetch = log.read(offset, adjustedFetchSize, maxOffsetOpt, minOneMessage, isolationLevel)

// 该分区正在被限速, 即限制访问该分区, 清空消息
if (shouldLeaderThrottle(quota, tp, replicaId))
FetchDataInfo(fetch.fetchOffsetMetadata, MemoryRecords.EMPTY)
// V3 版本开始 hardMaxBytesLimit 为 false, 如果第一条消息大小超过了 max_bytes 限制也会读取
// 为了防止客户端报错 RecordToolLargeException, 此时将过大的消息替换成空消息
else if (!hardMaxBytesLimit && fetch.firstEntryIncomplete)
FetchDataInfo(fetch.fetchOffsetMetadata, MemoryRecords.EMPTY)
else fetch

case None => // leader 副本在该分区不存在本地日志
error(s"Leader for partition $tp does not have a local log")
FetchDataInfo(LogOffsetMetadata.UnknownOffsetMetadata, MemoryRecords.EMPTY)
}

LogReadResult(info = logReadInfo, // localReplica.log 调用 read 方法的返回值
// localReplica 在内存中维护的 HW, LogStartOffset, LEO
highWatermark = initialHighWatermark,
leaderLogStartOffset = initialLogStartOffset,
leaderLogEndOffset = initialLogEndOffset,
// 请求中 follower 的 LogStartOffset, 客户端为 Consumer 则为 -1
followerLogStartOffset = followerLogStartOffset,
fetchTimeMs = fetchTimeMs, // 从本地读取数据之前记录的时间戳
readSize = partitionFetchSize, // NOTE: 这里是请求的 max_bytes 字段,而非实际读取字节数
lastStableOffset = lastStableOffset, // LSO, 用于事务实现
exception = None)
} catch {
// (...) 异常处理, 返回一个 exception 字段为捕获的异常, 其它字段都不合法的 LogReadResult
}
}

流程:首先取得本地副本(实际上对 Consumer 和 Follower 而言都是 Leader 副本),然后取得 HW,LEO 等字段,记录时间戳,然后通过本地副本读取本地数据。这里还利用了 V3 版本请求的 max_bytes 字段,限制读取的字节数上限,但如果第一条消息长度就超出上限的话,仍然会返回整条消息(此时读取字节数超过了 max_bytes)。

注意 LogReadResult 的第一个字段是从本地日志读取的信息:

1
2
3
4
5
case class FetchDataInfo(fetchOffsetMetadata: LogOffsetMetadata, // offset 元数据, 包括:
// offset; Segment 的基础 offset; 相对于 Segment 的物理偏移字节数
records: Records, // 消息集
firstEntryIncomplete: Boolean = false,
abortedTransactions: Option[List[AbortedTransaction]] = None

主要是前两个字段,消息集就不说了,元数据的作用是记录了 offset 对应消息相对本地 Segment 的实际偏移量。这里回顾一个基本概念,Kafka 的每个分区都用本地文件记录消息,为了防止单个文件过大,会根据文件大小和写入时间分成多个文件,单个文件被称为 Segment(对应代码中的 LogSegment 类),而 Log 类则是管理这些 Segment。因此,记录消息的物理偏移量,便于在从本地 Segment 中快速通过 offset 定位到对应消息。

接着看 readFromLocalLog 的逻辑:

1
2
3
4
5
6
7
8
9
10
11
12
13
var limitBytes = fetchMaxBytes // 初始值为 max_bytes 字段, 整个响应(多条消息)中可积累的最大字节数
val result = new mutable.ArrayBuffer[(TopicPartition, LogReadResult)]
var minOneMessage = !hardMaxBytesLimit
readPartitionInfo.foreach { case (tp, fetchInfo) =>
val readResult = read(tp, fetchInfo, limitBytes, minOneMessage) // 指定分区的读取结果
val recordBatchSize = readResult.info.records.sizeInBytes // 实际读取字节数
// 读取了至少一条消息, 那么以后严格遵守 max_bytes 的限制
if (recordBatchSize > 0)
minOneMessage = false
limitBytes = math.max(0, limitBytes - recordBatchSize)
result += (tp -> readResult)
}
result

可见,每个分区都对应一条读取结果(LogReadResult),包含 offset 对应消息,还有 HW/LEO 等信息 。V3 开始外部的 max_bytes 字段限制所有消息的最大字节数,而每个分区都有自己的 partition_max_bytes 限制单条消息的最大字节数。

读完这部分代码后,可以回顾 Fetch 请求的协议(V3 版本),并附上注释说明:

1
2
3
4
5
6
7
8
9
10
11
Fetch Request (Version: 3) => replica_id max_wait_time min_bytes max_bytes [topics] 
replica_id => INT32 // -1: Consumer, >= 0: Follower
max_wait_time => INT32 // 延迟请求中的 timeout,用于构造 DelayedFetch 对象
min_bytes => INT32 // 响应字节数超过则立刻发送响应,见ReplicaManager.fetchMessages
max_bytes => INT32 // 整个响应的最大字节数
topics => topic [partitions]
topic => STRING
partitions => partition fetch_offset partition_max_bytes
partition => INT32
fetch_offset => INT64
partition_max_bytes => INT32 // 每个分区消息的最大字节数

其中 fetch_offset 可由 FetchContext 的相关方法取得:

1
2
3
4
5
trait FetchContext extends Logging {
/**
* Get the fetch offset for a given partition.
*/
def getFetchOffset(part: TopicPartition): Option[Long]

updateFollowerLogReadResults

当 replica id 大于 0 时,代表客户端为 Follower,在从本地日志读取信息后,会调用该方法更新 Follower 的 fetch 状态,并更新读取结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
private def updateFollowerLogReadResults(replicaId: Int,
readResults: Seq[(TopicPartition, LogReadResult)]): Seq[(TopicPartition, LogReadResult)] = {
readResults.map { case (topicPartition, readResult) =>
var updatedReadResult = readResult
nonOfflinePartition(topicPartition) match {
case Some(partition) =>
partition.getReplica(replicaId) match {
case Some(replica) =>
case Some(replica) =>
// 首先更新分区上的 follower 的状态, 若 LW 或 HW 增加则返回 true
if (partition.updateReplicaLogReadResult(replica, readResult))
// 将 leader 副本的信息 (HW, LogStartOffset, LEO) 更新到读取结果上
partition.leaderReplicaIfLocal.foreach { leaderReplica =>
updatedReadResult = readResult.updateLeaderReplicaInfo(leaderReplica)
}
case None => // 当前副本不是分区的副本, 清空读取结果的 records 字段并将 metadata 标记为未知
// 略去日志...
updatedReadResult = readResult.withEmptyFetchInfo
}
case None => // 分区不可用(即 offline 分区), 打印警告日志, 不修改读取结果
warn(s"While recording the replica LEO, the partition $topicPartition hasn't been created.")
}
topicPartition -> updatedReadResult
}
}

然后对读取结果调用 updateLeaderReplicaInfo 更新为 leader 副本的信息:

1
2
3
4
def updateLeaderReplicaInfo(leaderReplica: Replica): LogReadResult =
copy(highWatermark = leaderReplica.highWatermark.messageOffset,
leaderLogStartOffset = leaderReplica.logStartOffset,
leaderLogEndOffset = leaderReplica.logEndOffset.messageOffset)

利用 Scala case 类的 copy 方法,返回更新对应字段后的对象。这里将读取结果的 HW,LogStartOffset,LEO 更新为 leader 副本维护的相应信息。因为 follower 副本发送 Fetch 请求时,leader 副本可能更新 HW(如果之前 follower 没有同步到最新),因此需要把更新后的 HW 发送给 follower。

顺带提下这里涉及到的另一个概念:低水位(LW, Low Watermark)。LW 即所有副本中最小的 LogStartOffset,一般情况下 LW 都是 0,但是如果服务端收到了 DeleteRecords 请求,删除日志文件的部分记录(消息)时,会更新 LW。

总结

本篇阅读了 Fetch 请求的处理流程,主要根据 replica id 字段分 Consumer 和 Follower 来处理:

  1. 会话认证,判断请求分区是否存在,将没有问题的分区对应的请求构成 Map 由 ReplicaManager 处理;
  2. ReplicaManager 对每个分区,找到其 leader 副本;
  3. leader 副本从本地读取请求的 offset 开始的若干消息(由全局的以及各分区的 max_bytes 字段来限制读取最大字节数),和维护的其它信息构成读取结果;
  4. 对 follower 副本的请求,还会将 leader 副本的 HW,LEO,LogStartOffset 更新到读取结果中;
  5. 根据读取结果和请求的相关字段判断是否立刻发送响应,比如读取没问题时,读取字节数超过了 min_bytes 即可发送;
  6. 否则,构造 DelayedFetch 对象传入 DelayedFetchPurgatory 对象中,此时 purgatory 还会判断一次处理是否完成,若已完成则不用延迟处理。

主要区别还是第 4 步,因为 follower 的 Fetch 请求是用来与 leader 同步的,因此需要将 HW 记录在结果中让 follower 更新自己的 HW。

Scala 笔记 - 函数式编程风格初探

前言

最近一段时间写 Scala 比较多,虽然用传统的风格写代码也没问题,但 Scala 既然提供了比较方便的函数式编程方式,那么还是入乡随俗,好好利用比较好。

目前写下来几个感受最深的还是:尽量避免使用 varmutable 数据结构。

上周末抽空看了下 Java 并发编程实战,基于锁和原子变量的线程安全实现虽然用起来很方便,但是一旦逻辑复杂了,锁住哪些变量,粒度多大,都会使代码变得比较复杂。而使用不可变(immutable)数据结构则能很大程度简化线程安全的实现(当然,锁和原子变量一定程度上还是需要的),因为不可变数据结构本身是线程安全的。

此外,Scala 不提供原生的 breakcontinue 来进行流程控制,虽然可以通过导入模块的方式使用,但尽量还是避免。

不过作为 C++er,有些地方从 C++ 转过来还是不太适应,因此记录一些笔记。

编程习惯的改变

Scala 使用 () 来使用下标访问集合,而其它大多数语言都是使用 [],比如用 C 风格代码遍历数组:

1
2
3
val array = Array(1, 2, 3)
for (i <- array.indices)
println(array(i)) // array(i) 而非 array[i]

原因很简单,因为 Scala 使用 [] 来指定泛型,而其它语言比如 C++ 和 Java 都是使用 <>,Scala 的 <> 都有各自用法:

Scala 有比其它大部分语言更为强大的 for 循环,比如:

1
2
3
val map = Map(1 -> "Java", 2 -> "C", 3 -> "Python", 4 -> "C++")
for ((key, value) <- map if value.matches("C.*"))
println(s"$key => $value")

打印出 value 以 C 开头的键值对,当然,Scala 也可以使用传统函数式编程常用的 mapfilter 等方法构成调用链:

1
2
3
map.filter(_._2.matches("C.*")).foreach { case (key, value) =>
println(s"$key => $value")
}

注意这里的 _2 是取得元组(前面说 -> 提到了二元元组)的第 2 个元素。两者是基于 Scala 强大的模式匹配(match)的,并且都用语法糖省略了多余代码。

这里由于仅仅是处理每个元素,因此用了 foreach,如果需要得到结果,比如将 println 打印的字符串构成数组或链表,就可以用 map,然后用 mkString 将换行符作为分隔符即可实现一样的功能:

1
2
val list = map.filter(_._2.matches("C.*")).map { case (key, value) => s"$key => $value" }
println(list.mkString("\n"))

因此 map 可谓最常用的方法了,前面说了,多线程访问共享的可变数据结构时存在 race condition,而如果使用 map 将共享的不可变数据结构映射为线程内部可见的另一个不可变数据结构时,则避免了 race condition,也不用麻烦地去加锁来解决,加锁还要考虑锁的粒度,还要谨慎思考粒度太小会不会导致线程不安全了。

当然,作为 C++er 的一个坏习惯就是过早优化,比如担心拷贝和内存多次分配的开销会不会太大。

然而实践起来,很多时候性能的瓶颈是网络延迟,磁盘 I/O,算法处理速度这些因素,而不是拷贝和内存分配。有些拷贝也是必要的,用 C++ 写,不加锁要实现线程安全,也得拷贝一份,至于内存分配,JVM 的 GC 经过一代代发展已经相当强大了。

这里不是说性能不重要,而是说,性能导致问题之前,编写不易出 BUG 的代码优先级更高。

流程控制

比如,从 N 台副本服务器上取数据,只要成功取得一份就返回,因此不需要并行操作,而是循环依次去访问服务器,这里给几个模拟实际场景的变量和方法:

1
2
3
4
5
6
7
val servers = Seq("Server1", "Server2", "Server3")

type DataType = String

def getDataFromServer(server: String): DataType = {
if (server.contains("2")) "Ok" else ""
}

return vs break

使用 return 单独写个方法:

1
2
3
4
5
6
7
8
def getData: DataType = {
for (server <- servers) {
val data = getDataFromServer(server)
if (data.nonEmpty)
return data
}
""
}

这里的 for (server <- servers) { 也可以换成 servers.foreach { server =>,看个人喜好。

当然也可以用 break,首先 import scala.util.control.Breaks,然后代码改成:

1
2
3
4
5
6
7
8
var data: DataType = ""
breakable {
servers.foreach { server =>
data = getDataFromServer(server)
if (data.nonEmpty)
break
}
}

这种方式仅仅是举例,实际写代码应该避免。除了少写一个函数外,没任何好处。

  1. 是破坏了不用var 的原则,虽然在这里并没有什么影响,但很容易让人逐渐依赖于 var;

  2. Scala 使用 break 本来就比其它语言的 break 要复杂,多了外层的 breakable 块;

  3. 底层是用异常来实现的:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    private class BreakControl extends ControlThrowable

    class Breaks {
    private val breakException = new BreakControl

    def breakable(op: => Unit) {
    try {
    op
    } catch {
    case ex: BreakControl =>
    if (ex ne breakException) throw ex
    }
    }

    def break(): Nothing = { throw breakException }
    }

    也就是说,如果在 breakable 块中处理异常时,还得额外捕获 ControlThrowable

    1
    2
    3
    4
    5
    6
    7
    8
    breakable {
    try {
    // 调用了 break 的代码(略)
    } catch {
    case e: ControlThrowable => throw e // 抛给外层来流程控制
    case e: Throwable => // TODO: 处理真正的异常
    }
    }

总之,break 仅仅是让习惯了其它语言的用户方便上手而已。

尾递归

之前 foreach 的做法还是避免不了 return 跳出循环,实际上 Scala 提供了尾递归优化,这里不详述,简单说就是将特殊条件下函数的递归调用优化替换成循环调用,并且无法优化的场景利用 tailrec 注解会抛出异常说明此处无法尾递归优化。

1
2
3
4
5
6
7
8
9
10
11
12
@scala.annotation.tailrec
def getData(servers: Seq[String]): DataType = {
servers match {
case Seq() => ""
case Seq(server, rest@_*) =>
val data = getDataFromServer(server)
if (data.nonEmpty)
data
else
getData(rest.toSeq)
}
}

上述代码利用了 Scala 强大的模式匹配能力。

第一行的 case Seq() 匹配空 Seq,也就是递归终止条件。

第二行的需要说明的是 rest@_*,进行匹配的是 _*_ 匹配类型,而 * 在 Scala 中则是匹配可变参数列表。前面在通过 @ 运算符将匹配到的可变参数列表绑定到变量 rest 上。

这里的主要好处还是略去了 return,在其它语言中 return 的一个好处是提前返回来避免 N 层缩进的难以阅读的代码。不过另一方面,return 的滥用会导致程序流程不是那么清晰,因为代码太长的话不知道前面哪里会直接 return了。不过其实说起来,这里的 return 是用在单独的方法中,相当于被隐藏了,也不会导致跳出方法外层的循环,可读性并不受影响。

一开始我是认为尾递归的方式比 return 更好,当时最大原因错以为 return 会抛出 ControlThrowable 来进行流程控制导致外层的 try-catch 需要额外处理这个异常,后来发现并不会抛出。因此在这里使用尾递归+模式匹配某种程度上有点炫技的意味。

另外,如果 servers 类型是 List 则可以用这种尾递归方式:

1
2
3
4
5
6
7
8
9
10
11
12
@scala.annotation.tailrec
def getData(servers: List[String]): DataType = {
servers match {
case Nil => ""
case server :: rest =>
val data = getDataFromServer(server)
if (data.nonEmpty)
data
else
getData(rest) // tail recursion
}
}

Future 的处理

Scala 可以用 Java 的线程设施来编写多线程程序,但是内置的 Future 一般情况下够用和好用了,最近用到的 play framework WS 模块PostGet 方法返回的都是 Future

一般情况下 FutureonComplete 方法,利用回调函数来处理正常返回或者异常发送的场景:

1
2
3
4
5
6
7
8
9
10
11
val s = "hello"
val future = Future {
if (s.length % 2 == 0)
s.length
else
throw new Exception(s""""$s"'s length is not even""")
}
future.onComplete {
case Failure(e) => // TODO: 处理 future 抛出的异常
case Success(result) => // TODO: 处理 future 的结果,这里即字符串长度
}

注意 Future 底层还是使用线程池的,因此需要导入 ExecutionContext,一般用默认的就行:

1
import scala.concurrent.ExecutionContext.Implicits.global

但是有些时候还是需要等待 Future 完成的,这里不回顾 Future 的语法,而是谈一些同步方面的处理。

等待多个 Future 结束

比如最常见的分块计算,然后将结果汇总:

1
2
3
4
5
6
7
8
9
val array = (1 to 10).toArray
val blockSize = 3

val blocks = (array.indices by blockSize).map { i => (i, i + blockSize) }.map { pair =>
if (pair._2 <= array.length) pair else pair._1 -> array.length
}
val futures = blocks.map { case (from, to) =>
Future(array.slice(from, to).sum)
} // Seq[Future[Int]]

现在得到了 Future[Int] 的序列,可以用回调函数的方法将结果存入一个 ConcurrentHashMap

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// 修改了 Future 的定义,加入了 from 作为 key
val futures = blocks.map { case (from, to) =>
Future(from -> array.slice(from, to).sum)
}
val results = new ConcurrentHashMap[Int, Int]()
futures.foreach { future =>
future.onComplete {
case Failure(e) => // TODO: 处理单个计算错误的异常
case Success((from, sum)) => results.put(from, sum)
}
}
Thread.sleep(100) // 模拟主线程做其它操作
while (futures.size != blocks.size) { // 无事可做了,轮询等待
Thread.sleep(100)
}
val sum = results.entrySet().asScala.map(_.getValue).sum

注意 asScala 将 Java 的集合类型转换成 Scala 对应的集合类型,在 Scala 2.13 之前需要:

1
import scala.collection.JavaConverters._

从 Scala 2.13 开始则是导入另一个包:

1
import scala.jdk.CollectionConverters._

PS:本文默认都是 Scala 2.8 以上。

回正题,如果主线程不需要做其他操作,就只想等待,那么这种基于回调的方式就未免过于麻烦,不如直接等待。但是 Await.result 或者 Await.ready 只能等待单个 Future,于是得 for 循环等待,然后还是得将结果一个个存入 HashMap 或者其它容器中。

Future 类提供了 sequence() 方法来简化这个操作,直接一行搞定:

1
val sum = Await.result(Future.sequence(futures), 2.seconds).sum

注意 seconds 需要导入相关包:

1
import scala.concurrent.duration._

sequence 方法将 Future[A] 的集合转换成 A 的集合的 Future,在这里即将Seq[Future[Int]] 转换成了 Future[Seq[Int]],这样直接等待就行了。

任务的顺序执行

有时候一个任务需要等待另一个任务完成后才能执行,此时可以用 FuturemapflatMap 方法:

1
2
3
4
val f1 = Future { 1 }
val f2 = f1.map { _ * 2}
val f3 = f2.map { _ * 3}
val result = Await.result(f3, 10.milliseconds)

map 方法接收的 block 是将结果类型映射到结果类型,但是 flatMap 方法是将结果类型映射到结果类型的 Future,有时候外部方法返回的是 Future 类型,此时就得用 flatMap

1
2
3
4
5
6
7
8
9
// 模拟外部的接口,比如 PlayWS
def getResponse(x: Int) = Future {
x * 2
}

val f1 = Future { 1 }
val f2 = f1.flatMap(getResponse)
val result = Await.result(f2, 10.milliseconds)
println(result)

假如任务数量不确定,也就回到前文提到的类型,多个服务器,取到一个就退出:

1
2
3
4
5
val servers = List("Server1", "Server2", "Server3")

def getResponseFromServer(server: String) = Future {
if (server.contains("2")) "Ok" else ""
}

这里可以用类似前文尾递归的方法得到新 Future

1
2
3
4
5
6
7
8
9
10
11
12
def getResponse(servers: List[String]): Future[String] = servers match {
case Nil => Future { "" }
case server :: rest =>
getResponseFromServer(server).flatMap { result =>
if (result.nonEmpty)
Future { result }
else
getResponse(rest)
}
}

val result = Await.result(getResponse(servers), 10.milliseconds)

这里无法用尾递归,因为递归调用发生在 flatMap 接收的 block 中,而非当前方法末尾。

如果把该方法做成同步调用就行了,因为反正要等待,不如直接每次 Future 都等待:

1
2
3
4
5
6
7
8
9
10
11
12
@scala.annotation.tailrec
def getResponse(servers: List[String]): String = servers match {
case Nil => ""
case server :: rest =>
val result = Await.result(getResponseFromServer(server), 10.milliseconds)
if (result.nonEmpty)
result
else
getResponse(rest)
}

val result = getResponse(servers)

这也是我在项目里实际采用的做法,这种做法有个缺点就是不便于扩展,如果返回 Future,那么如果以后要用到 getResponse 的结果,直接 mapflatMap 即可,但是现在的话,就必须同步等待 getResponse 完成了。

但是考虑到基本上没有进一步扩展的需求,目前就保持这样了。

总结

本文算是最近写 Scala 时的一些笔记,其实学 Scala 主要是为了看 Kafka 源码,Kafka 的 Scala 代码其实很多还是并不那么函数式的,毕竟很大一块基础设施还是 Java 写的然后 Scala 来调用,当然,不否认不少地方也用了函数式编程来节省代码和增加可读性。

比如在 1.1.0 版本的 ReplicaManager.scala 中,有一段代码:

1
2
val errorReadingData = logReadResultValues.foldLeft(false) ((errorIncurred, readResult) =>
errorIncurred || (readResult.error != Errors.NONE))

foldLeft 之前只用过一遍,所以看到这里根本不知道什么意思,实际上等价于

1
2
3
4
var errorReadingData = false
logReadResultValues.foreach { readResult =>
errorReadingData = (errorReadingData || (readResult.error != Errors.NONE))
}

也就是 logReadResultValues 中存在一个元素的 error 字段为 Errors.NONE 则为 false。当然,其实存在一个就可以 break 或 return 了,但后面继续循环也不会有什么显著性开销,所以用 foldLeft 非常简洁有效地实现了功能。重要的还是没有使用 var

Kafka源码阅读08: 写入本地日志的具体实现

回顾

06: Produce请求之写入本地日志中,对ReplicaManager类的appendToLocalLog方法的阅读,主要集中在对异常场景的处理:

  • 非admin客户端写入__consumer_offsets等特殊主题;
  • 找不到请求的主题+分区;
  • 请求的是离线分区;
  • 当前broker不是请求分区的leader;
  • 请求的acks字段不合法,或者为-1(all)但ISR数量小于min.insync.replicas配置。

会抛出异常被捕获后生成LogAppendResult对象(见server/ReplicaManager.scala)

1
2
3
4
5
6
case class LogAppendResult(info: LogAppendInfo, exception: Option[Throwable] = None) {
def error: Errors = exception match {
case None => Errors.NONE
case Some(e) => Errors.forException(e)
}
}

对上述异常场景,LogAppendResult.info被置为无效值:

1
2
3
4
object LogAppendInfo {
val UnknownLogAppendInfo = LogAppendInfo(-1, -1, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP, -1L,
RecordsProcessingStats.EMPTY, NoCompressionCodec, NoCompressionCodec, -1, -1, offsetsMonotonic = false)
}

appendToLocalLog返回的LogAppendResult07: Produce请求之发送响应 中会用来生成PartitionResponse对象和对应主题分区构成Map传给发送响应给客户端的回调函数中。

也就是说,最关键的部分我们之前暂且跳过了,也就是在正常清空下如何写入本地日志文件,然后生成LogAppendInfo

Log.append代码分析

cluster包的Partition.scala中,将当前分区的leaderEpoch字段传入了appendAsLeader

1
val info = log.appendAsLeader(records, leaderEpoch = this.leaderEpoch, isFromClient)

logLog对象,位于log包下的Log.scala。该方法会调用append

1
2
3
def appendAsLeader(records: MemoryRecords, leaderEpoch: Int, isFromClient: Boolean = true): LogAppendInfo = {
append(records, isFromClient, assignOffsets = true, leaderEpoch)
}

这里只考虑来自客户端的请求,因此接下来阅读时默认isFromClientassignOffsets为true

1
2
3
4
5
private def append(records: MemoryRecords, isFromClient: Boolean, assignOffsets: Boolean, leaderEpoch: Int): LogAppendInfo = {
maybeHandleIOException(s"Error while appending records to $topicPartition in dir ${dir.getParent}") {
// ...
}
}
1
2
3
4
5
6
7
8
9
private def maybeHandleIOException[T](msg: => String)(fun: => T): T = {
try {
fun
} catch {
case e: IOException =>
logDirFailureChannel.maybeAddOfflineLogDir(dir.getParent, msg, e)
throw new KafkaStorageException(msg, e)
}
}

maybeHandleIOException捕获fun可能抛出的IOException,进一步抛出KafkaStorageException会被上层捕获生成LogAppendResult

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
  // 校验消息的CRC以及消息长度(字节数)是否合法(不超过配置的 max.message.bytes), 并且会设置以下字段:
// - firstOffset: 第1条消息的offset, V2版本可以从header的firstOffset字段直接取得
// - lastOffset: 最后1条消息的offset, V2版本可以从header的firstOffset + lastOffsetDelta得到
// - shallowCount: 消息集的数量,shallow即浅层,以消息集为单位
// - validBytes: 所有长度合法的消息的长度之和
// - offsetsMonotic: 消息offset是否单调递增,利用每个消息集的lastOffset判断
// - sourceCodec: 生产者消息集的编码方式
val appendInfo = analyzeAndValidateRecords(records, isFromClient = isFromClient)

if (appendInfo.shallowCount == 0) // 没有合法消息则直接返回
return appendInfo

// 截断records中不合法的字节数, 然而按照前面的逻辑, 如果 analyzeAndValidateRecords 不抛出异常,
// appendInfo.validBytes 和 records.sizeInBytes 是相等的, 猜测是遗留方法?
var validRecords = trimInvalidBytes(records, appendInfo)

// 将 validRecords 插入到日志中, 由于可能其他处理线程也在将消息写入本地文件, 所以用锁保护
lock synchronized {
// 检查内存映射的缓冲区是否关闭, 比如在 closeHandlers() 会导致其关闭
// 若关闭, 则表示无法写入, 抛出 KafkaStorageException
checkIfMemoryMappedBufferClosed()
if (assignOffsets) {
// assign offsets to the message set
// 生产者发送的消息集的offset为0,1,...,n, nextOffsetMetadata记录了本地日志
// 下一条消息的offset, 将其置为新的firstOffset, 也就是绝对offset
val offset = new LongRef(nextOffsetMetadata.messageOffset)
appendInfo.firstOffset = offset.value
val now = time.milliseconds // 当前时间戳, 即LogAppendTime类型的时间戳
// 重新验证/解压/压缩得到更新内部offset后的validRecords
val validateAndOffsetAssignResult = try {
// 更新消息集的offset, 对于V1版本以上的消息, 可能因为时间戳类型字段来覆盖时间戳
LogValidator.validateMessagesAndAssignOffsets(validRecords,
offset, // 会更新为最后1条消息的绝对offset+1, 即下一次写入本地日志的消息的offset
time,
now,
appendInfo.sourceCodec,
appendInfo.targetCodec,
config.compact,
config.messageFormatVersion.messageFormatVersion.value,
config.messageTimestampType,
config.messageTimestampDifferenceMaxMs,
leaderEpoch,
isFromClient)
} catch {
case e: IOException => throw new KafkaException("Error in validating messages while appending to log '%s'".format(name), e)
}
validRecords = validateAndOffsetAssignResult.validatedRecords
// 设置 appendInfo 的以下字段:
// - maxTimestamp: 消息集的最大时间戳
// - offsetOfMaxTimestamp: 最大时间戳对应消息的绝对offset
// - lastOffset: 最后1条消息的offset
// - logAppendTime: 如果时间戳类型为LOG_APPEND_TIME, 则设为刚刚获取的时间戳
appendInfo.maxTimestamp = validateAndOffsetAssignResult.maxTimestamp
appendInfo.offsetOfMaxTimestamp = validateAndOffsetAssignResult.shallowOffsetOfMaxTimestamp
appendInfo.lastOffset = offset.value - 1
appendInfo.recordsProcessingStats = validateAndOffsetAssignResult.recordsProcessingStats
if (config.messageTimestampType == TimestampType.LOG_APPEND_TIME)
appendInfo.logAppendTime = now

// 重新验证消息大小是否超过max.message.size, 因为修改字段后重新压缩可能导致消息大小改变
if (validateAndOffsetAssignResult.messageSizeMaybeChanged) {
for (batch <- validRecords.batches.asScala) {
if (batch.sizeInBytes > config.maxMessageSize) {
// 更新stats(略)
throw new RecordTooLargeException("Message batch size is %d bytes which exceeds the maximum configured size of %d."
.format(batch.sizeInBytes, config.maxMessageSize))
}
}
}
} else {
// assignOffsets为false的处理(略)
}

// TODO: 对V2以上版本的消息集, 更新 leader epoch cache
validRecords.batches.asScala.foreach { batch =>
if (batch.magic >= RecordBatch.MAGIC_VALUE_V2)
_leaderEpochCache.assign(batch.partitionLeaderEpoch, batch.baseOffset)
}

// 检查消息集的总大小是否超过配置的segment.bytes, 即每个日志文件的大小
if (validRecords.sizeInBytes > config.segmentSize) {
throw new RecordBatchTooLargeException("Message batch size is %d bytes which exceeds the maximum configured segment size of %d."
.format(validRecords.sizeInBytes, config.segmentSize))
}

// now that we have valid records, offsets assigned, and timestamps updated, we need to
// validate the idempotent/transactional state of the producers and collect some metadata
// TODO: 验证生产者的 幂等性/事务 状态, 并收集一些元数据
val (updatedProducers, completedTxns, maybeDuplicate) = analyzeAndValidateProducerState(validRecords, isFromClient)
maybeDuplicate.foreach { duplicate =>
appendInfo.firstOffset = duplicate.firstOffset
appendInfo.lastOffset = duplicate.lastOffset
appendInfo.logAppendTime = duplicate.timestamp
appendInfo.logStartOffset = logStartOffset
return appendInfo
}

// 如果必要, 执行日志回滚策略, 将当前日志文件备份, 并创建空文件作为当前日志文件
val segment = maybeRoll(messagesSize = validRecords.sizeInBytes,
maxTimestampInMessages = appendInfo.maxTimestamp,
maxOffsetInMessages = appendInfo.lastOffset)

val logOffsetMetadata = LogOffsetMetadata(
messageOffset = appendInfo.firstOffset,
segmentBaseOffset = segment.baseOffset,
relativePositionInSegment = segment.size)

segment.append(firstOffset = appendInfo.firstOffset,
largestOffset = appendInfo.lastOffset,
largestTimestamp = appendInfo.maxTimestamp,
shallowOffsetOfMaxTimestamp = appendInfo.offsetOfMaxTimestamp,
records = validRecords)

// 更新生产者状态
for ((producerId, producerAppendInfo) <- updatedProducers) {
producerAppendInfo.maybeCacheTxnFirstOffsetMetadata(logOffsetMetadata)
producerStateManager.update(producerAppendInfo)
}

// update the transaction index with the true last stable offset. The last offset visible
// to consumers using READ_COMMITTED will be limited by this value and the high watermark.
// TODO: 用最新的稳定offset更新事务
for (completedTxn <- completedTxns) {
val lastStableOffset = producerStateManager.completeTxn(completedTxn)
segment.updateTxnIndex(completedTxn, lastStableOffset)
}

producerStateManager.updateMapEndOffset(appendInfo.lastOffset + 1)

// 更新 nextOffsetMetadata 为 lastOffset+1, 回顾之前if (assignOffsets)分支
// 在下一批消息到达时, 该offset即新消息集的第1个消息的绝对offset
updateLogEndOffset(appendInfo.lastOffset + 1)

// TODO: update the first unstable offset (which is used to compute LSO)
updateFirstUnstableOffset()

// trace日志(略)

// 若未冲刷的消息数量(利用LEO减去recoverPoint得到)达到了配置的"flush.messages"则冲刷消息
if (unflushedMessages >= config.flushInterval)
flush()

appendInfo
}
}

注释中标出TODO的部分暂时还不了解原理,包括且不限于:

  • leader epoch;
  • 对事务的支持;
  • stable offset(似乎也是用于事务?)

流程总结

首先是代码逻辑的大体流程:

  1. records.batches为一组Record Batch,对每个batch都校验CRC是否合法,字节数是否超过配置max.message.bytes
  2. 若存在不合法的batch,则会抛出异常最终被ReplicaManager.appendToLocalLog捕获(仅限于Produce请求处理的情况),生成包含错误的响应;
  3. 利用records计算出第1条消息和最后1条消息的offset,消息集的数量,合法batch的字节数之和,消息offset是否单调递增,以及消息集的编码方式,构造要返回的LogAppendInfo对象,记为info
  4. 验证合法消息的数量,并截断不合法的字节数,得到validRecords;(TODO:此处实现似乎不合理,因为存在不合法的batch直接就抛异常了,但当前最新版本2.3的Kafka源码也是这么处理的
  5. 检测内存映射缓存是否被关闭;
  6. 将LEO赋给info.firstOffset,并取得当前时间戳now
  7. 更新validRecords的offset为绝对offset,若batch是压缩的则重新压缩,将最后1条消息的offset赋给info.lastOffset,并设置info的消息集最大时间戳及对应消息的offset;
  8. 若时间戳类型为LOG_APPEND_TIME,将now赋给info.logAppendTime(默认为-1);
  9. 若重新压缩的validRecords字节数发生变化,重新检查每个batch的字节数是否超过配置max.message.bytes
  10. 检查validRecords字节数是否超过配置log.segments.bytes
  11. 取得当前的LogSegment对象,将validRecords添加进去;
  12. 更新LEO为validRecords最后1条消息的offset+1;
  13. 若未冲刷的消息数量超过了配置flush.messages,则将所有LogSegments写入本地磁盘。

核心还是用绝对offset替换相对offset。生产者向服务端发送请求时,由于不知道消息集落盘时的offset,所以只能设置offsets为0,1,2,…n-1,也就是相对offset。而分区的leader broker则维护了其LEO,因此收到请求时,会将offsets修改为LEO,LEO+1,LEO+2,…LEO+n-1,最后将LEO更新为LEO+n。而更新的offsets会包含在响应里发送给生产者,这样客户端就可以通过消息送达的回调函数得到发送消息的绝对offset。

每个Log对应1个分区的消息日志,而消息日志是分为多个文件(日志片段,Log Segment)对应LogSegment对象,负责实际写入磁盘。

这里回顾用到的3个Kafka服务端配置:

  • max.message.bytes:每个消息集的最大字节数(这是0.11开始的含义,见upgrade 0.11 message format
  • log.segment.bytes:Log Segment的最大字节数(所以需要检测消息集字节数是否超过这个值,否则即使新建文件写入消息集,也无法容纳整个消息集);
  • flush.messages(Topic级别):允许LogSegment对象缓存的消息数量,如果缓存消息数超过了该配置就会调用fsync写入磁盘。

此外,Record Batch消息集(Message Set),Record(记录)即Message(消息)。之所以这里区分,是因为从Kafka 0.11版本开始,消息集的概念发生了变化。在此之前,消息集仅仅是一组消息之前加上Log Overhead(即offset和message size)。而Kafka 0.11版本新增了,V2版本的消息集,增加了独有的header,比如第1条消息和最后1条消息的offset可直接通过header计算得到,还有些其他字段是leader epoch以及实现事务相关的字段。而每条消息(记录)的key和value用varint而非固定4字节的int表示长度,并且消息本身也有header。

具体参考:https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-Messagesets

Kafka源码阅读07: Produce请求(2): 发送响应

前文回顾

前一篇阅读了appendToLocalLog的部分,服务端在收到Produce请求时,会首先将消息写入本地消息日志:

1
2
3
4
5
val sTime = time.milliseconds // 取得当前毫秒级时间戳
val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed,
isFromClient = isFromClient, entriesPerPartition, requiredAcks)
// 调试信息: 再次取得时间戳, 相减得到 appendToLocalLog 的用时
debug("Produce to local log in %d ms".format(time.milliseconds - sTime))

返回的localProduceResults类型是Map[TopicPartition, LogAppendResult]

1
2
3
4
5
6
7
8
// 正常结果: info有效,exception为None
// 错误结果: info各字段设为无效值,exception为某种异常,可通过error方法取得其错误信息
case class LogAppendResult(info: LogAppendInfo, exception: Option[Throwable] = None) {
def error: Errors = exception match {
case None => Errors.NONE
case Some(e) => Errors.forException(e)
}
}

info字段来自于Log.appendAsLeader的返回值,即实际添加到本地日志的消息,包含消息集的第1条消息和最后1条消息的offset(生产者在发送消息集时是不知道最后写入日志文件时消息的offset,只有在服务端写入日志时才会设置)。

接下来阅读ReplicaManager.appendRecords中的后续处理。

ProducePartitionStatus的处理

1
2
3
4
5
6
7
8
9
10
11
12
// 将分区对应的处理结果转换成 ProducePartitionStatus 对象
val produceStatus = localProduceResults.map { case (topicPartition, result) =>
topicPartition ->
ProducePartitionStatus(
// lastOffset + 1 代表下一批消息的第1条消息的 offset
result.info.lastOffset + 1, // required offset
// 利用 LogAppendInfo 的各字段构造 PartitionResponse
new PartitionResponse(result.error, result.info.firstOffset, result.info.logAppendTime, result.info.logStartOffset)) // response status
}

// 通过 KafkaApis.handleProduceRequest 传入的回调更新 KafkaApis.brokerTopicStats
processingStatsCallback(localProduceResults.mapValues(_.info.recordsProcessingStats))
1
2
3
4
5
6
7
public static final class PartitionResponse {
public Errors error; // 错误信息
public long baseOffset; // 消息集中第1条消息的offset
public long logAppendTime; // 消息集写入日志文件时的时间戳
public long logStartOffset; // 消息集写入日志文件时,日志文件的起始offset
// ...
}

回顾一下,在使用Kafka客户端时,生产者可以通过回调取得消息的元数据,像主题和分区,是在生产者发送前就已知的,但offset和时间戳则是由服务端在此处设置的。见Kafka 1.1 Producer APIRecordMetadata

接下来是一个if-else分支

1
2
3
4
5
6
7
8
if (delayedProduceRequestRequired(requiredAcks, entriesPerPartition, localProduceResults)) {
// ...
} else {
// we can respond immediately
// 取得 PartitionStatus 作为新的value传进 responseCallback, 即忽略了 offset 字段
val produceResponseStatus = produceStatus.mapValues(status => status.responseStatus)
responseCallback(produceResponseStatus)
}

如果delayedProduceRequestRequired返回false,则可以立刻发送响应,而且忽略了offset字段,因为该字段代表了下一批消息的第1个offset,而PartitionStatus本身就包含当前消息集的baseOffset

那么为何else分支就意味着可以立刻发送响应呢?

1
2
3
4
5
6
7
8
private def delayedProduceRequestRequired(requiredAcks: Short,
entriesPerPartition: Map[TopicPartition, MemoryRecords],
localProduceResults: Map[TopicPartition, LogAppendResult]): Boolean = {
requiredAcks == -1 &&
entriesPerPartition.nonEmpty &&
// exception字段为Option类型,若不为None则isDefined返回true
localProduceResults.values.count(_.exception.isDefined) < entriesPerPartition.size
}

可见,if分支意味着以下条件满足:

  1. requiredAcks为-1,即生产者要等待分区的所有ISR收到消息后才会返回;
  2. entriesPerPartition不为空,即存在需要添加消息的分区;
  3. localProduceResults中至少存在1条成功的结果。

相应地,else分支对应的是:

  1. requiredAcks为0或1,即客户端无需等待服务端的响应或者只需要等待leader收到消息;
  2. 没有消息需要写入(无论是没有可写入的分区还是全部消息写入出现异常),那么ISR也没必要去从leader复制数据,因此也可以直接返回响应。

PS:第2个条件在处理Produce请求是是多余的判断,因为之前在KafkaApis.handleProduceRequest中已经判断过了:

1
2
3
4
5
6
7
8
9
10
if (authorizedRequestInfo.isEmpty)
sendResponseCallback(Map.empty)
else { // authorizedRequestInfo非空, 传入参数entriesPerPartition
// ...
replicaManager.appendRecords(
// ...
entriesPerPartition = authorizedRequestInfo,
/* ... */)
// ...
}

也就是说if分支里会等待所有ISR收到消息才会返回,查看if分支:

1
2
3
4
5
6
7
8
9
10
// 构造 DelayedProduce 对象, 注意 timeout 仅在此处使用
val produceMetadata = ProduceMetadata(requiredAcks, produceStatus)
val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock)

// 通过 topic 和 partition 创建用于延迟生产操作的key
val producerRequestKeys = entriesPerPartition.keys.map(new TopicPartitionOperationKey(_)).toSeq

// 尝试立刻完成请求, 否则会将请求放入 purgatory 中, 因为在创建 DelayedProduce 对象时,
// 新的请求可能会到达, 从而使得这个操作处于可完成状态
delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)

还是利用了purgatory,先不研究其实现细节,大致可以理解为,创建一个DelayedProduce对象,传入带offset和时间戳的消息集,设置timeout和响应回调,就能完成延迟生产。而purgatory只是用来确认是否完成,若没完成则将其扔进purgatory中。

也就是说,响应回调不再是像else分支(以及之前的错误处理分支)中一样由当前线程去调用,而是由DelayedProduce对象去调用,从而实现了异步的方式等待所有ISR收到最新的消息,避免leader的Handler线程阻塞在KafkaApis对请求的处理中。

另外,值得注意的是timeout是在构造这个DelayedProduce对象时才使用,也就是之前的写入本地日志的时间是不计算在内的,当然网络传输时间也是,可以回顾上一篇阅读笔记2.1 请求格式中翻译的官网对timeout的说明。

sendResponseCallback

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def sendResponseCallback(responseStatus: Map[TopicPartition, PartitionResponse]) {

// 合并 responseStatus 和之前认证失败或者主题不存在产生的错误响应
val mergedResponseStatus = responseStatus ++ unauthorizedTopicResponses ++ nonExistingTopicResponses
var errorInResponse = false

// 检查是否有错误响应, 若有则将 errorInResponse 置为 true, 并将错误写入debug日志
mergedResponseStatus.foreach { case (topicPartition, status) =>
if (status.error != Errors.NONE) {
errorInResponse = true
// 写入debug日志(略)
}
}

def produceResponseCallback(bandwidthThrottleTimeMs: Int) {
if (produceRequest.acks == 0) {
// acks为0意味着客户端无需等待服务端响应, 因此服务端无需操作
// 但是如果请求存在错误, 服务端需要关闭socket来通知客户端有错误发生, 然后更新元数据
if (errorInResponse) { // 存在错误响应
// 首先转换成 Map[TopicPartition, String], 其value为异常的类型名称, 然后将其字符串表示写入日志
val exceptionsSummary = mergedResponseStatus.map { case (topicPartition, status) =>
topicPartition -> status.error.exceptionName
}.mkString(", ")
// 写入info日志(略)
closeConnection(request, new ProduceResponse(mergedResponseStatus.asJava).errorCounts)
} else { // 不存在错误响应
sendNoOpResponseExemptThrottle(request)
}
} else { // acks为1或者-1
sendResponseMaybeThrottle(request, requestThrottleMs =>
new ProduceResponse(mergedResponseStatus.asJava, bandwidthThrottleTimeMs + requestThrottleMs))
}
}

// When this callback is triggered, the remote API call has completed
// 无论是在哪个处理分支, 这个回调函数必定是在远程API调用结束后才触发
request.apiRemoteCompleteTimeNanos = time.nanoseconds

quotas.produce.maybeRecordAndThrottle(
request.session.sanitizedUser, // session认证用户名(没配置SSL认证则是ANONYMOUS)
request.header.clientId,
numBytesAppended,
produceResponseCallback)
}

由于是接着之前的进行阅读,所以用到了一些之前创建的对象,见上一篇阅读笔记handleProduceRequest

  • unauthorizedTopicResponses:对调用KafkaApis.authorize方法认证失败的请求生成的错误响应;
  • nonExistingTopicResponses:对目标主题不在KafkaApis.metadataCache中的请求生产的错误响应;
  • numBytesAppended:请求的总字节数,包含header部分。

检测出是否有错误响应是为了传给produceResponseCallback,从而在acks为0时,关闭与客户端的socket连接来通知其更新元数据。而该回调被传入了ClientQuotaManager.maybeRecordAndThrottle方法,在未启用quotas的情况下会直接调用produceResponseCallback,分为以下3种情况:

  1. acks为0,且存在错误响应:关闭与客户端的连接,会引起客户端更新元数据;

  2. acks为0,且不存在错误响应:

    1
    sendNoOpResponseExemptThrottle(request)
    1
    2
    3
    4
    private def sendNoOpResponseExemptThrottle(request: RequestChannel.Request): Unit = {
    quotas.request.maybeRecordExempt(request)
    sendResponse(request, None)
    }

    会进入KafkaApis.sendResponseNone分支:

    1
    requestChannel.sendResponse(new RequestChannel.Response(request, None, NoOpAction, None))

    回顾网络层阅读的之Acceptor和Processor4.2 processNewResponses,如果响应的类型是NoOpAction,只会给Processor与客户端的连接Channel重新注册读事件,并不会发送响应给客户端。

  3. acks不为0:

    1
    2
    sendResponseMaybeThrottle(request,
    requestThrottleMs => new ProduceResponse(mergedResponseStatus.asJava, bandwidthThrottleTimeMs + requestThrottleMs))

    创建ProduceResponsethrottleMsbandwidthThrottleTimeMsrequestThrottleMs之和,这两者都有各自对应的quotas对象,若未启用则为0。最终也会进入KafkaApis.sendResponse中:

    1
    2
    3
    4
    5
    val responseSend = request.context.buildResponse(response)
    val responseString =
    if (RequestChannel.isRequestLoggingEnabled) Some(response.toString(request.context.apiVersion))
    else None
    requestChannel.sendResponse(new RequestChannel.Response(request, Some(responseSend), SendAction, responseString))

    SendAction类型的响应通过RequestChannel交给Processor,进一步发送给客户端。

总结

本篇阅读了处理Produce请求的流程,接着写入本地日志后的代码继续阅读:

写入本地日志后会返回处理结果,包含了每个请求写入的分区的相关状态,新增了消息集的baseOffset和写入日志的时间戳。对于acks字段为-1的情况,将timeout字段/消息集以及发送响应的回调丢给DelayedOperation对象进行异步的延迟操作,并通过purgatory字段检查异步处理的结果。

无论是KafkaApis本身,还是DelayedOperation,处理完后都会调用sendResponseCallback,acks不为0则根据Produce响应协议构造响应发送给客户端,acks为0则根据是否有错误响应而有不同的行为,若不包含错误响应则不进行操作,否则关闭socket,触发客户端重新获取元数据。

至此,完成了服务端对Produce请求的阅读,但是有不少细节没有深入,有待进一步阅读:

  • DelayedOperationDelayedOperationPurgatory:延迟操作的实现;
  • Log类,对本地日志目录和日志片段(segment)文件直接操作;
  • Partition类,管理了分区的副本broker,还有leader epoch等。