Avoid queueing into circbuffer when the channel is about to close

This commit is contained in:
Matt Johnston 2015-11-19 23:52:52 +08:00
parent 87373be960
commit 90c3a74b2a

View File

@ -42,7 +42,7 @@ static void send_msg_channel_open_failure(unsigned int remotechan, int reason,
static void send_msg_channel_open_confirmation(struct Channel* channel, static void send_msg_channel_open_confirmation(struct Channel* channel,
unsigned int recvwindow, unsigned int recvwindow,
unsigned int recvmaxpacket); unsigned int recvmaxpacket);
static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf, static int writechannel(struct Channel* channel, int fd, circbuffer *cbuf,
const unsigned char *moredata, unsigned int *morelen); const unsigned char *moredata, unsigned int *morelen);
static void send_msg_channel_window_adjust(struct Channel *channel, static void send_msg_channel_window_adjust(struct Channel *channel,
unsigned int incr); unsigned int incr);
@ -100,15 +100,6 @@ void chancleanup() {
TRACE(("leave chancleanup")) TRACE(("leave chancleanup"))
} }
static void
chan_initwritebuf(struct Channel *channel)
{
dropbear_assert(channel->writebuf->size == 0 && channel->recvwindow == 0);
cbuf_free(channel->writebuf);
channel->writebuf = cbuf_new(opts.recv_window);
channel->recvwindow = opts.recv_window;
}
/* Create a new channel entry, send a reply confirm or failure */ /* Create a new channel entry, send a reply confirm or failure */
/* If remotechan, transwindow and transmaxpacket are not know (for a new /* If remotechan, transwindow and transmaxpacket are not know (for a new
* outgoing connection, with them to be filled on confirmation), they should * outgoing connection, with them to be filled on confirmation), they should
@ -167,8 +158,8 @@ static struct Channel* newchannel(unsigned int remotechan,
newchan->await_open = 0; newchan->await_open = 0;
newchan->flushing = 0; newchan->flushing = 0;
newchan->writebuf = cbuf_new(0); /* resized later by chan_initwritebuf */ newchan->writebuf = cbuf_new(opts.recv_window);
newchan->recvwindow = 0; newchan->recvwindow = opts.recv_window;
newchan->extrabuf = NULL; /* The user code can set it up */ newchan->extrabuf = NULL; /* The user code can set it up */
newchan->recvdonelen = 0; newchan->recvdonelen = 0;
@ -379,7 +370,6 @@ void channel_connect_done(int result, int sock, void* user_data, const char* UNU
{ {
channel->readfd = channel->writefd = sock; channel->readfd = channel->writefd = sock;
channel->conn_pending = NULL; channel->conn_pending = NULL;
chan_initwritebuf(channel);
send_msg_channel_open_confirmation(channel, channel->recvwindow, send_msg_channel_open_confirmation(channel, channel->recvwindow,
channel->recvmaxpacket); channel->recvmaxpacket);
TRACE(("leave channel_connect_done: success")) TRACE(("leave channel_connect_done: success"))
@ -436,7 +426,7 @@ static void send_msg_channel_eof(struct Channel *channel) {
} }
#ifndef HAVE_WRITEV #ifndef HAVE_WRITEV
static void writechannel_fallback(struct Channel* channel, int fd, circbuffer *cbuf, static int writechannel_fallback(struct Channel* channel, int fd, circbuffer *cbuf,
const unsigned char *UNUSED(moredata), unsigned int *morelen) { const unsigned char *UNUSED(moredata), unsigned int *morelen) {
unsigned char *circ_p1, *circ_p2; unsigned char *circ_p1, *circ_p2;
@ -455,23 +445,24 @@ static void writechannel_fallback(struct Channel* channel, int fd, circbuffer *c
if (errno != EINTR && errno != EAGAIN) { if (errno != EINTR && errno != EAGAIN) {
TRACE(("channel IO write error fd %d %s", fd, strerror(errno))) TRACE(("channel IO write error fd %d %s", fd, strerror(errno)))
close_chan_fd(channel, fd, SHUT_WR); close_chan_fd(channel, fd, SHUT_WR);
return DROPBEAR_FAILURE;
}
} }
} else {
cbuf_incrread(cbuf, written); cbuf_incrread(cbuf, written);
channel->recvdonelen += written; channel->recvdonelen += written;
} return DROPBEAR_SUCCESS;
} }
#endif /* !HAVE_WRITEV */ #endif /* !HAVE_WRITEV */
#ifdef HAVE_WRITEV #ifdef HAVE_WRITEV
static void writechannel_writev(struct Channel* channel, int fd, circbuffer *cbuf, static int writechannel_writev(struct Channel* channel, int fd, circbuffer *cbuf,
const unsigned char *moredata, unsigned int *morelen) { const unsigned char *moredata, unsigned int *morelen) {
struct iovec iov[3]; struct iovec iov[3];
unsigned char *circ_p1, *circ_p2; unsigned char *circ_p1, *circ_p2;
unsigned int circ_len1, circ_len2; unsigned int circ_len1, circ_len2;
int io_count = 0; int io_count = 0;
int cbuf_written;
ssize_t written; ssize_t written;
cbuf_readptrs(cbuf, &circ_p1, &circ_len1, &circ_p2, &circ_len2); cbuf_readptrs(cbuf, &circ_p1, &circ_len1, &circ_p2, &circ_len2);
@ -503,7 +494,7 @@ static void writechannel_writev(struct Channel* channel, int fd, circbuffer *cbu
From common_recv_msg_channel_data() then channelio(). From common_recv_msg_channel_data() then channelio().
The second call may not have any data to write, so we just return. */ The second call may not have any data to write, so we just return. */
TRACE(("leave writechannel, no data")) TRACE(("leave writechannel, no data"))
return; return DROPBEAR_SUCCESS;
} }
if (morelen) { if (morelen) {
@ -517,29 +508,32 @@ static void writechannel_writev(struct Channel* channel, int fd, circbuffer *cbu
if (errno != EINTR && errno != EAGAIN) { if (errno != EINTR && errno != EAGAIN) {
TRACE(("channel IO write error fd %d %s", fd, strerror(errno))) TRACE(("channel IO write error fd %d %s", fd, strerror(errno)))
close_chan_fd(channel, fd, SHUT_WR); close_chan_fd(channel, fd, SHUT_WR);
return DROPBEAR_FAILURE;
} }
} else { }
int cbuf_written = MIN(circ_len1+circ_len2, (unsigned int)written);
cbuf_written = MIN(circ_len1+circ_len2, (unsigned int)written);
cbuf_incrread(cbuf, cbuf_written); cbuf_incrread(cbuf, cbuf_written);
if (morelen) { if (morelen) {
*morelen = written - cbuf_written; *morelen = written - cbuf_written;
} }
channel->recvdonelen += written; channel->recvdonelen += written;
} return DROPBEAR_SUCCESS;
} }
#endif /* HAVE_WRITEV */ #endif /* HAVE_WRITEV */
/* Called to write data out to the local side of the channel. /* Called to write data out to the local side of the channel.
Writes the circular buffer contents and also the "moredata" buffer Writes the circular buffer contents and also the "moredata" buffer
if not null. Will ignore EAGAIN */ if not null. Will ignore EAGAIN.
static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf, Returns DROPBEAR_FAILURE if writing to fd had an error and the channel is being closed, DROPBEAR_SUCCESS otherwise */
static int writechannel(struct Channel* channel, int fd, circbuffer *cbuf,
const unsigned char *moredata, unsigned int *morelen) { const unsigned char *moredata, unsigned int *morelen) {
int ret = DROPBEAR_SUCCESS;
TRACE(("enter writechannel fd %d", fd)) TRACE(("enter writechannel fd %d", fd))
#ifdef HAVE_WRITEV #ifdef HAVE_WRITEV
writechannel_writev(channel, fd, cbuf, moredata, morelen); ret = writechannel_writev(channel, fd, cbuf, moredata, morelen);
#else #else
writechannel_fallback(channel, fd, cbuf, moredata, morelen); ret = writechannel_fallback(channel, fd, cbuf, moredata, morelen);
#endif #endif
/* Window adjust handling */ /* Window adjust handling */
@ -555,6 +549,7 @@ static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf,
channel->recvwindow <= cbuf_getavail(channel->extrabuf)); channel->recvwindow <= cbuf_getavail(channel->extrabuf));
TRACE(("leave writechannel")) TRACE(("leave writechannel"))
return ret;
} }
@ -829,6 +824,7 @@ void common_recv_msg_channel_data(struct Channel *channel, int fd,
unsigned int buflen; unsigned int buflen;
unsigned int len; unsigned int len;
unsigned int consumed; unsigned int consumed;
int res;
TRACE(("enter recv_msg_channel_data")) TRACE(("enter recv_msg_channel_data"))
@ -861,7 +857,7 @@ void common_recv_msg_channel_data(struct Channel *channel, int fd,
/* Attempt to write the data immediately without having to put it in the circular buffer */ /* Attempt to write the data immediately without having to put it in the circular buffer */
consumed = datalen; consumed = datalen;
writechannel(channel, fd, cbuf, buf_getptr(ses.payload, datalen), &consumed); res = writechannel(channel, fd, cbuf, buf_getptr(ses.payload, datalen), &consumed);
datalen -= consumed; datalen -= consumed;
buf_incrpos(ses.payload, consumed); buf_incrpos(ses.payload, consumed);
@ -869,7 +865,9 @@ void common_recv_msg_channel_data(struct Channel *channel, int fd,
/* We may have to run throught twice, if the buffer wraps around. Can't /* We may have to run throught twice, if the buffer wraps around. Can't
* just "leave it for next time" like with writechannel, since this * just "leave it for next time" like with writechannel, since this
* is payload data */ * is payload data.
* If the writechannel() failed then remaining data is discarded */
if (res == DROPBEAR_SUCCESS) {
len = datalen; len = datalen;
while (len > 0) { while (len > 0) {
buflen = cbuf_writelen(cbuf); buflen = cbuf_writelen(cbuf);
@ -881,6 +879,7 @@ void common_recv_msg_channel_data(struct Channel *channel, int fd,
buf_incrpos(ses.payload, buflen); buf_incrpos(ses.payload, buflen);
len -= buflen; len -= buflen;
} }
}
TRACE(("leave recv_msg_channel_data")) TRACE(("leave recv_msg_channel_data"))
} }
@ -993,8 +992,6 @@ void recv_msg_channel_open() {
channel->prio = DROPBEAR_CHANNEL_PRIO_BULK; channel->prio = DROPBEAR_CHANNEL_PRIO_BULK;
} }
chan_initwritebuf(channel);
/* success */ /* success */
send_msg_channel_open_confirmation(channel, channel->recvwindow, send_msg_channel_open_confirmation(channel, channel->recvwindow,
channel->recvmaxpacket); channel->recvmaxpacket);
@ -1137,7 +1134,6 @@ int send_msg_channel_open_init(int fd, const struct ChanType *type) {
/* Outbound opened channels don't make use of in-progress connections, /* Outbound opened channels don't make use of in-progress connections,
* we can set it up straight away */ * we can set it up straight away */
chan_initwritebuf(chan);
/* set fd non-blocking */ /* set fd non-blocking */
setnonblocking(fd); setnonblocking(fd);