逐步实现TCP服务端Step02-2:改进

到目前为止服务端处理消息的情况是这样的:

先从socket缓冲中尽力取字节,把取到的字节放到一个单字节数组中去(recv字节队列)。然后,从这个数组中取字节,组成消息。

每次要发送消息时,都要先尝试发送post字节队列中残留的字节,然后再去发送本次实际要发送的字节。如果此次发送没有将期望的字节量全部发出,就把剩余的字节存到post队列中,等待下次调用SendOneMessageEx时再尝试发送。

发送字节的操作是在SendOneMessageEx中进行的,如果SendOneMessageEx不被调用的话,残留的字节就没机会被发送。这个地方不太合理,说白了SendOneMessageEx也应该拆成两个过程,就像之前的RecvOneMessage被拆成RecvBytes和GetOneMessage一样:SetOneMessage只负责把消息(长度前缀加消息体)放入post字节队列,SendBytes只负责把post字节队列中的字节发送出去。得到的好处就是:不论SetOneMessage何时被调用,都不会耽误字节的发送工作。SetOneMessage是一种更高层次的操作,更偏向“业务”。或者说,对这个函数的调用往往发生在业务的处理过程中。这里再增加一个函数:ProcessOneMessage,它专注于处理基于当前消息的相关事务,而对SetOneMessage函数的调用,应该发生在这个函数中。

这样一来,主要的流程将如下图所示:

base.h

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
#ifndef __BASE_H__
#define __BASE_H__

#include <unistd.h>
#include <fcntl.h>
#include <cstring>
#include <cerrno>

// Size in bytes of the length prefix that precedes every message on the wire.
#define MSG_PREFIX_LEN      (4)
// Capacity in bytes of the receive queue (g_recv_buf).
#define MAX_RECV_BUF_LEN    (128)
// Capacity in bytes of the post (pending-send) queue (g_post_buf).
#define MAX_POST_BUF_LEN    (128)

// Receive queue: raw bytes read from the socket that have not yet been
// turned into messages.
unsigned char g_recv_buf[MAX_RECV_BUF_LEN];
// Post queue: framed bytes that are waiting to be written to the socket.
unsigned char g_post_buf[MAX_POST_BUF_LEN];
// Index of the first unconsumed byte in g_recv_buf.
int g_read_begin = 0;
// Index one past the last received byte in g_recv_buf.
int g_read_end   = 0;
// Index of the first unsent byte in g_post_buf.
int g_post_begin = 0;
// Index one past the last queued byte in g_post_buf.
int g_post_end   = 0;

// Reads as many bytes as the socket currently offers into the receive
// queue (g_recv_buf), appending at g_read_end.
//
// Before reading, the queue is rewound when it is empty, and compacted to
// the front when its tail has reached the end of the buffer while already
// consumed bytes still occupy the front. Without that compaction a partial
// message stuck at the tail would make every later call report "full".
//
// retval:
//  0 - success (the socket has no more data available right now)
// -1 - sock invalid
// -2 - recv buf is full
// -3 - client close
// -4 - recv error
int RecvBytes(int sock)
{
    if (sock <= 0)
        return -1;
    if (g_read_begin == g_read_end) {
        g_read_begin = 0;
        g_read_end   = 0;
    } else if (g_read_begin > 0 && g_read_end >= MAX_RECV_BUF_LEN) {
        // Reclaim the space of bytes that were already consumed.
        const int pending = g_read_end - g_read_begin;
        memmove(g_recv_buf, &g_recv_buf[g_read_begin], pending);
        g_read_begin = 0;
        g_read_end   = pending;
    }
    for (;;) {
        if (g_read_end >= MAX_RECV_BUF_LEN)
            return -2;
        const int recvd_bytes_cnt = (int)recv(sock,
            &g_recv_buf[g_read_end],
            MAX_RECV_BUF_LEN - g_read_end, 0);
        if (recvd_bytes_cnt > 0) {
            g_read_end += recvd_bytes_cnt;
            continue;
        }
        if (0 == recvd_bytes_cnt)
            return -3;
        // recv() failed: inspect errno right away, before anything can
        // overwrite it.
        if (EINTR == errno)
            continue;   // interrupted by a signal, simply retry
        if (EAGAIN == errno || EWOULDBLOCK == errno)
            return 0;   // socket drained for now
        return -4;
    }
}

// Extracts one complete message from the receive queue (g_recv_buf).
//
// len: in  - capacity of msg in bytes;
//      out - length taken from the message's prefix. It is overwritten as
//            soon as a prefix has been read, i.e. for every return value
//            except -1 and -2.
// msg: destination buffer for the message body (prefix not included).
//
// retval:
//  0 - success, msg holds len bytes and the message was removed from the queue
// -1 - msg is NULL
// -2 - the queue does not hold more than a length prefix yet
// -3 - the prefix announces an empty message; the queue is discarded
// -4 - msg is too small for the announced length
// -5 - the announced message can never fit into the receive queue
// -6 - the message is not completely received yet
int GetOneMessage(unsigned int& len, unsigned char* msg)
{
    if (!msg) return -1;
    const int bytes_cnt = g_read_end - g_read_begin;
    if (bytes_cnt <= MSG_PREFIX_LEN)
        return -2;
    const unsigned int max_buf_len = len;
    // The prefix can start at any byte offset inside the queue, so copy it
    // out instead of dereferencing a casted (possibly misaligned) pointer.
    unsigned int net_len = 0;
    memcpy(&net_len, &g_recv_buf[g_read_begin], sizeof(net_len));
    len = ntohl(net_len);
    if (0 == len) {
        g_read_begin = g_read_end = 0;
        return -3;
    }
    if (max_buf_len < len)
        return -4;
    // bytes_cnt > MSG_PREFIX_LEN was checked above, so this is positive.
    const unsigned int payload_avail
        = (unsigned int)(bytes_cnt - MSG_PREFIX_LEN);
    if (payload_avail < len) {
        // The message is incomplete. Make sure the rest of it can still be
        // appended behind the bytes that are already queued.
        const bool never_fits
            = len > (unsigned int)(MAX_RECV_BUF_LEN - MSG_PREFIX_LEN);
        if (never_fits
            || g_read_begin + len + MSG_PREFIX_LEN
               >= (unsigned int)MAX_RECV_BUF_LEN) {
            memmove(g_recv_buf,
                &g_recv_buf[g_read_begin], bytes_cnt);
            g_read_begin = 0;
            g_read_end = bytes_cnt;
            if (never_fits)
                return -5;
        }
        return -6;
    }
    memcpy(msg, &g_recv_buf[g_read_begin + MSG_PREFIX_LEN], len);
    g_read_begin += MSG_PREFIX_LEN + (int)len;
    if (g_read_begin == g_read_end)
        g_read_begin = g_read_end = 0;
    return 0;
}

int SendReservedBytes(int sock)
{
    if (sock <= 0) return -1;
    int sent_bytes = 0;
    int left_bytes = g_post_end - g_post_begin; 
    if (!left_bytes) return 0;
    unsigned char* tmp_post_buf = &g_post_buf[g_post_begin]; 
    while (1) {
        sent_bytes = send(sock, 
            tmp_post_buf, left_bytes, 0);
        if (sent_bytes > 0) {
            tmp_post_buf += sent_bytes;
            left_bytes -= sent_bytes;
            g_post_begin += sent_bytes;
            if (!left_bytes) {
                g_post_begin = g_post_end = 0;
                return 0;
            }
        } else if (sent_bytes < 0) {
            if (EAGAIN == errno)
                return -2;
            else
                return -3;
        } else 
            return -3;
    }
    return 0;
}

// Appends bytes_cnt bytes to the post queue (g_post_buf) so that they can
// be sent later. Bytes already queued are first moved to the front of the
// buffer to make room.
//
// retval:
//  0 - the bytes were queued (or bytes_cnt was 0 and nothing had to be done)
// -1 - bytes_cnt is negative, bytes is NULL, or the queue has not enough
//      free space; the queue is left unchanged
int ReserveBytes(int bytes_cnt, const unsigned char* bytes)
{
    // A negative count would slip through the capacity check below and
    // reach memcpy() as a huge size, so reject it up front.
    if (bytes_cnt < 0) return -1;
    if (0 == bytes_cnt) return 0;
    if (!bytes) return -1;
    const int existing_bytes_cnt = g_post_end - g_post_begin;
    // An append that fills the buffer exactly is still valid.
    if (bytes_cnt > MAX_POST_BUF_LEN - existing_bytes_cnt)
        return -1;
    if (g_post_begin > 0) {
        memmove(g_post_buf, &g_post_buf[g_post_begin],
                existing_bytes_cnt);
        g_post_begin = 0;
        g_post_end = existing_bytes_cnt;
    }
    memcpy(&g_post_buf[g_post_end], bytes, bytes_cnt);
    g_post_end += bytes_cnt;
    return 0;
}

int SetOneMessage(unsigned int len, const unsigned char* msg)
{
    int sent_bytes = 0;
    int left_bytes = MSG_PREFIX_LEN + len;
    unsigned char* buf 
        = new unsigned char[left_bytes];
    if (!buf) return -1;
    *((int*)buf) = htonl(len);
    memcpy(&buf[MSG_PREFIX_LEN], msg, len);
    if (ReserveBytes(left_bytes, buf) < 0) {
        delete[] buf;
        return -1;
    }
    return 0;
}

int SendBytes(int sock)
{
    if (sock <= 0) return -1;
    int sent_bytes = 0;
    int left_bytes = g_post_end - g_post_begin; 
    if (!left_bytes) return 0;
    unsigned char* tmp_post_buf = &g_post_buf[g_post_begin]; 
    while (1) {
        sent_bytes = send(sock, 
            tmp_post_buf, left_bytes, 0);
        if (sent_bytes > 0) {
            tmp_post_buf += sent_bytes;
            left_bytes -= sent_bytes;
            g_post_begin += sent_bytes;
            if (!left_bytes) {
                g_post_begin = g_post_end = 0;
                return 0;
            }
        } else if (sent_bytes < 0) {
            if (EAGAIN == errno)
                return -2;
            else
                return -3;
        } else 
            return -3;
    }
    return 0;
}

int SendOneMessageEx(int sock, unsigned int len, const unsigned char* msg)
{    
    if (!len) return -1;
    if (!msg) return -1;
    if (sock <= 0) return -1;
    int sent_bytes = 0;
    int left_bytes = MSG_PREFIX_LEN + len;
    unsigned char* buf 
        = new unsigned char[left_bytes];
    unsigned char* tmpbuf = buf; 
    if (!buf) return -1;
    *((int*)tmpbuf) = htonl(len);
    memcpy(&tmpbuf[MSG_PREFIX_LEN], msg, len);
    int retval = SendReservedBytes(sock);
    if (retval < 0) {
        if (-2 == retval) {
            if (ReserveBytes(left_bytes, tmpbuf) < 0) {
                delete[] buf;
                return -1;
            }
        } else { 
            delete[] buf;
            return -1;
        }
        delete[] buf;
        return -2;
    }
    while (left_bytes > 0) {
        sent_bytes = send(sock, tmpbuf, left_bytes, 0);
        if (sent_bytes > 0) {
            tmpbuf += sent_bytes;
            left_bytes -= sent_bytes;
        } else if (sent_bytes < 0) {
            if (EAGAIN == errno) {
                if (ReserveBytes(left_bytes, tmpbuf) < 0) {
                    delete[] buf;
                    return -1;
                }
                delete[] buf;
                return -2;
            }
        } else { 
            delete[] buf;
            return -1;
        }
    }
    delete[] buf;
    return 0;
}

// Adds O_NONBLOCK to the file status flags of sock.
//
// retval:
// -1 - sock is not positive, or reading the current flags failed
// otherwise the result of fcntl(sock, F_SETFL, ...)
int SetSockNonBlock(int sock)
{
    if (sock <= 0) {
        return -1;
    }
    const int current_flags = fcntl(sock, F_GETFL, 0);
    if (current_flags < 0) {
        return -1;
    }
    // Keep every flag that is already set and add the non-blocking bit.
    return fcntl(sock, F_SETFL, current_flags | O_NONBLOCK);
}

// Receives exactly one framed message from sock: first the length prefix,
// then the body. The loops keep calling recv() until the expected number of
// bytes has arrived.
//
// len: in  - capacity of msg in bytes (must not be 0);
//      out - body length announced by the prefix.
// msg: destination buffer for the message body.
//
// retval:
// >0 - total number of bytes consumed (prefix + body)
//  0 - recv() returned 0 before the message was complete
// -1 - invalid argument, recv error, or the body does not fit into msg
int RecvOneMessage(int sock, unsigned int& len, unsigned char* msg)
{
    if (!len) return -1;
    if (!msg) return -1;
    if (sock <= 0) return -1;
    int recvd_bytes_cnt = 0;
    unsigned char prefix[MSG_PREFIX_LEN];
    int prefix_bytes_cnt = 0;
    while (prefix_bytes_cnt != MSG_PREFIX_LEN) {
        recvd_bytes_cnt = (int)recv(sock,
            &prefix[prefix_bytes_cnt],
            MSG_PREFIX_LEN - prefix_bytes_cnt, 0);
        if (recvd_bytes_cnt < 0)
            return -1;
        if (0 == recvd_bytes_cnt)
            return 0;
        prefix_bytes_cnt += recvd_bytes_cnt;
    }
    const unsigned int max_buf_len = len;
    // Copy the prefix out instead of dereferencing a casted pointer.
    unsigned int net_len = 0;
    memcpy(&net_len, prefix, sizeof(net_len));
    len = ntohl(net_len);
    if (len > max_buf_len) {
        return -1;
    }
    unsigned int body_bytes_cnt = 0;
    while (body_bytes_cnt != len) {
        // Ask only for the rest of THIS message. Asking for the rest of the
        // buffer could swallow the beginning of the next message and push
        // the counter past len.
        recvd_bytes_cnt = (int)recv(sock,
            &msg[body_bytes_cnt],
            len - body_bytes_cnt, 0);
        if (recvd_bytes_cnt < 0)
            return -1;
        if (0 == recvd_bytes_cnt)
            return 0;
        body_bytes_cnt += (unsigned int)recvd_bytes_cnt;
    }
    return (int)body_bytes_cnt + MSG_PREFIX_LEN;
}

// Builds one frame (length prefix in network byte order followed by the
// body) and hands it to send() in a single call.
//
// retval:
// >=0 - number of bytes send() reported as written
//  -1 - invalid argument or send() failed
int SendOneMessage(int sock, unsigned int len, const unsigned char* msg)
{
    if (0 == len || NULL == msg || sock <= 0)
        return -1;
    const unsigned int frame_len = MSG_PREFIX_LEN + len;
    unsigned char* frame = new unsigned char[frame_len];
    if (NULL == frame)
        return -1;
    // Prefix first, body right behind it.
    const unsigned int wire_len = htonl(len);
    memcpy(frame, &wire_len, sizeof(wire_len));
    memcpy(frame + MSG_PREFIX_LEN, msg, len);
    const int sent_bytes_cnt = send(sock, frame, frame_len, 0);
    delete[] frame;
    return (sent_bytes_cnt < 0) ? -1 : sent_bytes_cnt;
}

#endif /* __BASE_H__ */

s.cc

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <unistd.h>
#include <sys/types.h>  
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <errno.h>
#include "base.h"

int ProcessOneMessage(unsigned int len, const unsigned char* msg);

int main(int argc, char**argv)
{
    if (argc != 2) {
        std::clog << "Usage: " 
            << argv[0] << " port" << std::endl;
        return 0;
    }
    int svr_sock, cli_sock;
    sockaddr_in svr_addr, cli_addr;
    if ((svr_sock = socket(AF_INET, 
        SOCK_STREAM, 0)) < 0) {
        return 1;
    }
    memset(&svr_addr, 0, sizeof(svr_addr));
    svr_addr.sin_family = AF_INET;
    svr_addr.sin_port = htons(atoi(argv[1]));
    svr_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    if (bind(svr_sock, 
        (const sockaddr*)&svr_addr, 
        sizeof(svr_addr)) < 0) {
        return 1;
    }
    if (listen(svr_sock, 4) < 0) {
        return 1;
    }
    while (1) {
        memset(&cli_addr, 0, sizeof(cli_addr));
        socklen_t addr_len = sizeof(cli_addr);
        if ((cli_sock = accept(svr_sock, 
            (sockaddr*)&cli_addr, &addr_len)) < 0) {
            return 1;
        }
        if (SetSockNonBlock(cli_sock) < 0) {
            std::cerr << "set sock nonblock failed!" << std::endl;
            return 1;
        }
        unsigned char msg[128] = {0};
        while (1) {
            unsigned int msg_len = sizeof(msg);
            memset(msg, 0, msg_len);
            int retval = RecvBytes(cli_sock);
            if (-2 == retval) {
                std::clog << "recv buf is full." << std::endl;
            } else if (-3 == retval) {
                std::clog << "cli close!" << std::endl;
                break;
            } else if (retval < 0) 
                return 1;
            bool is_get_message = true;
            retval = GetOneMessage(msg_len, msg);
            if (retval < 0) {
                if (-2 == retval || -6 == retval) {
                    is_get_message = false;
                } else return 1;
            }
            // Process current message.
            if (is_get_message) {
                if (ProcessOneMessage(msg_len, msg) < 0) 
                    return 1;
            }
            retval = SendBytes(cli_sock);
            if (retval < 0) {
                if (-2 == retval) {
                    continue;
                }
                return 1;
            }
        }
        close(cli_sock);
        g_read_begin = g_read_end = 0;
    }
    return 0;
}

// Handles one received message: logs its body and queues the same bytes as
// the reply through SetOneMessage().
//
// len: length of the message body in bytes.
// msg: message body; it does not have to be NUL-terminated.
//
// retval:
// -1 - msg is NULL
// otherwise the result of SetOneMessage()
int ProcessOneMessage(unsigned int len, const unsigned char* msg)
{
    if (!msg) return -1;
    // Write exactly len bytes. Streaming msg as a C string would depend on
    // a terminating NUL that the message itself does not carry.
    std::clog.write(reinterpret_cast<const char*>(msg), len);
    std::clog << std::endl;
    return SetOneMessage(len, msg);
}

<==  index  ==>