/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Robert Collins
  • Date: 2007-04-19 02:27:44 UTC
  • mto: This revision was merged to the branch mainline in revision 2426.
  • Revision ID: robertc@robertcollins.net-20070419022744-pfdqz42kp1wizh43
``make docs`` now creates a man page at ``man1/bzr.1`` fixing bug 107388.
(Robert Collins)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006,2007 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
16
 
17
17
"""The 'medium' layer for the smart servers and clients.
18
18
 
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
breezy/transport/smart/__init__.py.
 
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
 
import errno
28
 
import io
29
27
import os
 
28
import socket
30
29
import sys
31
 
import time
 
30
from bzrlib import errors
 
31
from bzrlib.smart.protocol import SmartServerRequestProtocolOne
32
32
 
33
33
try:
34
 
    import _thread
35
 
except ImportError:
36
 
    import thread as _thread
37
 
 
38
 
import breezy
39
 
from ...lazy_import import lazy_import
40
 
lazy_import(globals(), """
41
 
import select
42
 
import socket
43
 
import weakref
44
 
 
45
 
from breezy import (
46
 
    debug,
47
 
    trace,
48
 
    transport,
49
 
    ui,
50
 
    urlutils,
51
 
    )
52
 
from breezy.i18n import gettext
53
 
from breezy.bzr.smart import client, protocol, request, signals, vfs
54
 
from breezy.transport import ssh
55
 
""")
56
 
from ... import (
57
 
    errors,
58
 
    osutils,
59
 
    )
60
 
 
61
 
# Throughout this module buffer size parameters are either limited to be at
62
 
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
63
 
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
64
 
# from non-sockets as well.
65
 
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
66
 
 
67
 
 
68
 
class HpssVfsRequestNotAllowed(errors.BzrError):
69
 
 
70
 
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
71
 
            "%(method)s, %(arguments)s.")
72
 
 
73
 
    def __init__(self, method, arguments):
74
 
        self.method = method
75
 
        self.arguments = arguments
76
 
 
77
 
 
78
 
def _get_protocol_factory_for_bytes(bytes):
79
 
    """Determine the right protocol factory for 'bytes'.
80
 
 
81
 
    This will return an appropriate protocol factory depending on the version
82
 
    of the protocol being used, as determined by inspecting the given bytes.
83
 
    The bytes should have at least one newline byte (i.e. be a whole line),
84
 
    otherwise it's possible that a request will be incorrectly identified as
85
 
    version 1.
86
 
 
87
 
    Typical use would be::
88
 
 
89
 
         factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
90
 
         server_protocol = factory(transport, write_func, root_client_path)
91
 
         server_protocol.accept_bytes(unused_bytes)
92
 
 
93
 
    :param bytes: a str of bytes of the start of the request.
94
 
    :returns: 2-tuple of (protocol_factory, unused_bytes).  protocol_factory is
95
 
        a callable that takes three args: transport, write_func,
96
 
        root_client_path.  unused_bytes are any bytes that were not part of a
97
 
        protocol version marker.
98
 
    """
99
 
    if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
100
 
        protocol_factory = protocol.build_server_protocol_three
101
 
        bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
102
 
    elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
103
 
        protocol_factory = protocol.SmartServerRequestProtocolTwo
104
 
        bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
105
 
    else:
106
 
        protocol_factory = protocol.SmartServerRequestProtocolOne
107
 
    return protocol_factory, bytes
108
 
 
109
 
 
110
 
def _get_line(read_bytes_func):
111
 
    """Read bytes using read_bytes_func until a newline byte.
112
 
 
113
 
    This isn't particularly efficient, so should only be used when the
114
 
    expected size of the line is quite short.
115
 
 
116
 
    :returns: a tuple of two strs: (line, excess)
117
 
    """
118
 
    newline_pos = -1
119
 
    bytes = b''
120
 
    while newline_pos == -1:
121
 
        new_bytes = read_bytes_func(1)
122
 
        bytes += new_bytes
123
 
        if new_bytes == b'':
124
 
            # Ran out of bytes before receiving a complete line.
125
 
            return bytes, b''
126
 
        newline_pos = bytes.find(b'\n')
127
 
    line = bytes[:newline_pos + 1]
128
 
    excess = bytes[newline_pos + 1:]
129
 
    return line, excess
130
 
 
131
 
 
132
 
class SmartMedium(object):
133
 
    """Base class for smart protocol media, both client- and server-side."""
134
 
 
135
 
    def __init__(self):
136
 
        self._push_back_buffer = None
137
 
 
138
 
    def _push_back(self, data):
139
 
        """Return unused bytes to the medium, because they belong to the next
140
 
        request(s).
141
 
 
142
 
        This sets the _push_back_buffer to the given bytes.
143
 
        """
144
 
        if not isinstance(data, bytes):
145
 
            raise TypeError(data)
146
 
        if self._push_back_buffer is not None:
147
 
            raise AssertionError(
148
 
                "_push_back called when self._push_back_buffer is %r"
149
 
                % (self._push_back_buffer,))
150
 
        if data == b'':
151
 
            return
152
 
        self._push_back_buffer = data
153
 
 
154
 
    def _get_push_back_buffer(self):
155
 
        if self._push_back_buffer == b'':
156
 
            raise AssertionError(
157
 
                '%s._push_back_buffer should never be the empty string, '
158
 
                'which can be confused with EOF' % (self,))
159
 
        bytes = self._push_back_buffer
160
 
        self._push_back_buffer = None
161
 
        return bytes
162
 
 
163
 
    def read_bytes(self, desired_count):
164
 
        """Read some bytes from this medium.
165
 
 
166
 
        :returns: some bytes, possibly more or less than the number requested
167
 
            in 'desired_count' depending on the medium.
168
 
        """
169
 
        if self._push_back_buffer is not None:
170
 
            return self._get_push_back_buffer()
171
 
        bytes_to_read = min(desired_count, _MAX_READ_SIZE)
172
 
        return self._read_bytes(bytes_to_read)
173
 
 
174
 
    def _read_bytes(self, count):
175
 
        raise NotImplementedError(self._read_bytes)
176
 
 
177
 
    def _get_line(self):
178
 
        """Read bytes from this request's response until a newline byte.
179
 
 
180
 
        This isn't particularly efficient, so should only be used when the
181
 
        expected size of the line is quite short.
182
 
 
183
 
        :returns: a string of bytes ending in a newline (byte 0x0A).
184
 
        """
185
 
        line, excess = _get_line(self.read_bytes)
186
 
        self._push_back(excess)
187
 
        return line
188
 
 
189
 
    def _report_activity(self, bytes, direction):
190
 
        """Notify that this medium has activity.
191
 
 
192
 
        Implementations should call this from all methods that actually do IO.
193
 
        Be careful that it's not called twice, if one method is implemented on
194
 
        top of another.
195
 
 
196
 
        :param bytes: Number of bytes read or written.
197
 
        :param direction: 'read' or 'write' or None.
198
 
        """
199
 
        ui.ui_factory.report_transport_activity(self, bytes, direction)
200
 
 
201
 
 
202
 
_bad_file_descriptor = (errno.EBADF,)
203
 
if sys.platform == 'win32':
204
 
    # Given on Windows if you pass a closed socket to select.select. Probably
205
 
    # also given if you pass a file handle to select.
206
 
    WSAENOTSOCK = 10038
207
 
    _bad_file_descriptor += (WSAENOTSOCK,)
208
 
 
209
 
 
210
 
class SmartServerStreamMedium(SmartMedium):
 
34
    from bzrlib.transport import ssh
 
35
except errors.ParamikoNotPresent:
 
36
    # no paramiko.  SmartSSHClientMedium will break.
 
37
    pass
 
38
 
 
39
 
 
40
class SmartServerStreamMedium(object):
211
41
    """Handles smart commands coming over a stream.
212
42
 
213
43
    The stream may be a pipe connected to sshd, or a tcp socket, or an
216
46
    One instance is created for each connected client; it can serve multiple
217
47
    requests in the lifetime of the connection.
218
48
 
219
 
    The server passes requests through to an underlying backing transport,
 
49
    The server passes requests through to an underlying backing transport, 
220
50
    which will typically be a LocalTransport looking at the server's filesystem.
221
 
 
222
 
    :ivar _push_back_buffer: a str of bytes that have been read from the stream
223
 
        but not used yet, or None if there are no buffered bytes.  Subclasses
224
 
        should make sure to exhaust this buffer before reading more bytes from
225
 
        the stream.  See also the _push_back method.
226
51
    """
227
52
 
228
 
    _timer = time.time
229
 
 
230
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
53
    def __init__(self, backing_transport):
231
54
        """Construct new server.
232
55
 
233
56
        :param backing_transport: Transport for the directory served.
234
57
        """
235
58
        # backing_transport could be passed to serve instead of __init__
236
59
        self.backing_transport = backing_transport
237
 
        self.root_client_path = root_client_path
238
60
        self.finished = False
239
 
        if timeout is None:
240
 
            raise AssertionError('You must supply a timeout.')
241
 
        self._client_timeout = timeout
242
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
243
 
        SmartMedium.__init__(self)
244
61
 
245
62
    def serve(self):
246
63
        """Serve requests until the client disconnects."""
249
66
        from sys import stderr
250
67
        try:
251
68
            while not self.finished:
252
 
                server_protocol = self._build_protocol()
253
 
                self._serve_one_request(server_protocol)
254
 
        except errors.ConnectionTimeout as e:
255
 
            trace.note('%s' % (e,))
256
 
            trace.log_exception_quietly()
257
 
            self._disconnect_client()
258
 
            # We reported it, no reason to make a big fuss.
259
 
            return
260
 
        except Exception as e:
 
69
                protocol = SmartServerRequestProtocolOne(self.backing_transport,
 
70
                                                         self._write_out)
 
71
                self._serve_one_request(protocol)
 
72
        except Exception, e:
261
73
            stderr.write("%s terminating on exception %s\n" % (self, e))
262
74
            raise
263
 
        self._disconnect_client()
264
 
 
265
 
    def _stop_gracefully(self):
266
 
        """When we finish this message, stop looking for more."""
267
 
        trace.mutter('Stopping %s' % (self,))
268
 
        self.finished = True
269
 
 
270
 
    def _disconnect_client(self):
271
 
        """Close the current connection. We stopped due to a timeout/etc."""
272
 
        # The default implementation is a no-op, because that is all we used to
273
 
        # do when disconnecting from a client. I suppose we never had the
274
 
        # *server* initiate a disconnect, before
275
 
 
276
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
277
 
        """Wait for more bytes to be read, but timeout if none available.
278
 
 
279
 
        This allows us to detect idle connections, and stop trying to read from
280
 
        them, without setting the socket itself to non-blocking. This also
281
 
        allows us to specify when we watch for idle timeouts.
282
 
 
283
 
        :return: Did we timeout? (True if we timed out, False if there is data
284
 
            to be read)
285
 
        """
286
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
287
 
 
288
 
    def _build_protocol(self):
289
 
        """Identifies the version of the incoming request, and returns an
290
 
        a protocol object that can interpret it.
291
 
 
292
 
        If more bytes than the version prefix of the request are read, they will
293
 
        be fed into the protocol before it is returned.
294
 
 
295
 
        :returns: a SmartServerRequestProtocol.
296
 
        """
297
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
298
 
        if self.finished:
299
 
            # We're stopping, so don't try to do any more work
300
 
            return None
301
 
        bytes = self._get_line()
302
 
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
303
 
        protocol = protocol_factory(
304
 
            self.backing_transport, self._write_out, self.root_client_path)
305
 
        protocol.accept_bytes(unused_bytes)
306
 
        return protocol
307
 
 
308
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
309
 
        """select() on a file descriptor, waiting for nonblocking read()
310
 
 
311
 
        This will raise a ConnectionTimeout exception if we do not get a
312
 
        readable handle before timeout_seconds.
313
 
        :return: None
314
 
        """
315
 
        t_end = self._timer() + timeout_seconds
316
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
317
 
        rs = xs = None
318
 
        while not rs and not xs and self._timer() < t_end:
319
 
            if self.finished:
320
 
                return
321
 
            try:
322
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
323
 
            except (select.error, socket.error) as e:
324
 
                err = getattr(e, 'errno', None)
325
 
                if err is None and getattr(e, 'args', None) is not None:
326
 
                    # select.error doesn't have 'errno', it just has args[0]
327
 
                    err = e.args[0]
328
 
                if err in _bad_file_descriptor:
329
 
                    return  # Not a socket indicates read() will fail
330
 
                elif err == errno.EINTR:
331
 
                    # Interrupted, keep looping.
332
 
                    continue
333
 
                raise
334
 
            except ValueError:
335
 
                return  # Socket may already be closed
336
 
        if rs or xs:
337
 
            return
338
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
339
 
                                       % (timeout_seconds,))
340
75
 
341
76
    def _serve_one_request(self, protocol):
342
77
        """Read one request from input, process, send back a response.
343
 
 
 
78
        
344
79
        :param protocol: a SmartServerRequestProtocol.
345
80
        """
346
 
        if protocol is None:
347
 
            return
348
81
        try:
349
82
            self._serve_one_request_unguarded(protocol)
350
83
        except KeyboardInterrupt:
351
84
            raise
352
 
        except Exception as e:
 
85
        except Exception, e:
353
86
            self.terminate_due_to_error()
354
87
 
355
88
    def terminate_due_to_error(self):
356
89
        """Called when an unhandled exception from the protocol occurs."""
357
90
        raise NotImplementedError(self.terminate_due_to_error)
358
91
 
359
 
    def _read_bytes(self, desired_count):
360
 
        """Get some bytes from the medium.
361
 
 
362
 
        :param desired_count: number of bytes we want to read.
363
 
        """
364
 
        raise NotImplementedError(self._read_bytes)
365
 
 
366
92
 
367
93
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
368
94
 
369
 
    def __init__(self, sock, backing_transport, root_client_path='/',
370
 
                 timeout=None):
 
95
    def __init__(self, sock, backing_transport):
371
96
        """Constructor.
372
97
 
373
98
        :param sock: the socket the server will read from.  It will be put
374
99
            into blocking mode.
375
100
        """
376
 
        SmartServerStreamMedium.__init__(
377
 
            self, backing_transport, root_client_path=root_client_path,
378
 
            timeout=timeout)
 
101
        SmartServerStreamMedium.__init__(self, backing_transport)
 
102
        self.push_back = ''
379
103
        sock.setblocking(True)
380
104
        self.socket = sock
381
 
        # Get the getpeername now, as we might be closed later when we care.
382
 
        try:
383
 
            self._client_info = sock.getpeername()
384
 
        except socket.error:
385
 
            self._client_info = '<unknown>'
386
 
 
387
 
    def __str__(self):
388
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
389
 
 
390
 
    def __repr__(self):
391
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
392
 
                                     self._client_info)
393
105
 
394
106
    def _serve_one_request_unguarded(self, protocol):
395
107
        while protocol.next_read_size():
396
 
            # We can safely try to read large chunks.  If there is less data
397
 
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
398
 
            # short read immediately rather than block.
399
 
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
400
 
            if bytes == b'':
401
 
                self.finished = True
402
 
                return
403
 
            protocol.accept_bytes(bytes)
404
 
 
405
 
        self._push_back(protocol.unused_data)
406
 
 
407
 
    def _disconnect_client(self):
408
 
        """Close the current connection. We stopped due to a timeout/etc."""
409
 
        self.socket.close()
410
 
 
411
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
412
 
        """Wait for more bytes to be read, but timeout if none available.
413
 
 
414
 
        This allows us to detect idle connections, and stop trying to read from
415
 
        them, without setting the socket itself to non-blocking. This also
416
 
        allows us to specify when we watch for idle timeouts.
417
 
 
418
 
        :return: None, this will raise ConnectionTimeout if we time out before
419
 
            data is available.
420
 
        """
421
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
422
 
 
423
 
    def _read_bytes(self, desired_count):
424
 
        return osutils.read_bytes_from_socket(
425
 
            self.socket, self._report_activity)
426
 
 
 
108
            if self.push_back:
 
109
                protocol.accept_bytes(self.push_back)
 
110
                self.push_back = ''
 
111
            else:
 
112
                bytes = self.socket.recv(4096)
 
113
                if bytes == '':
 
114
                    self.finished = True
 
115
                    return
 
116
                protocol.accept_bytes(bytes)
 
117
        
 
118
        self.push_back = protocol.excess_buffer
 
119
    
427
120
    def terminate_due_to_error(self):
 
121
        """Called when an unhandled exception from the protocol occurs."""
428
122
        # TODO: This should log to a server log file, but no such thing
429
123
        # exists yet.  Andrew Bennetts 2006-09-29.
430
124
        self.socket.close()
431
125
        self.finished = True
432
126
 
433
127
    def _write_out(self, bytes):
434
 
        tstart = osutils.perf_counter()
435
 
        osutils.send_all(self.socket, bytes, self._report_activity)
436
 
        if 'hpss' in debug.debug_flags:
437
 
            thread_id = _thread.get_ident()
438
 
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
439
 
                         % ('wrote', thread_id, len(bytes),
440
 
                            osutils.perf_counter() - tstart))
 
128
        self.socket.sendall(bytes)
441
129
 
442
130
 
443
131
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
444
132
 
445
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
133
    def __init__(self, in_file, out_file, backing_transport):
446
134
        """Construct new server.
447
135
 
448
136
        :param in_file: Python file from which requests can be read.
449
137
        :param out_file: Python file to write responses.
450
138
        :param backing_transport: Transport for the directory served.
451
139
        """
452
 
        SmartServerStreamMedium.__init__(self, backing_transport,
453
 
                                         timeout=timeout)
 
140
        SmartServerStreamMedium.__init__(self, backing_transport)
454
141
        if sys.platform == 'win32':
455
142
            # force binary mode for files
456
143
            import msvcrt
461
148
        self._in = in_file
462
149
        self._out = out_file
463
150
 
464
 
    def serve(self):
465
 
        """See SmartServerStreamMedium.serve"""
466
 
        # This is the regular serve, except it adds signal trapping for soft
467
 
        # shutdown.
468
 
        stop_gracefully = self._stop_gracefully
469
 
        signals.register_on_hangup(id(self), stop_gracefully)
470
 
        try:
471
 
            return super(SmartServerPipeStreamMedium, self).serve()
472
 
        finally:
473
 
            signals.unregister_on_hangup(id(self))
474
 
 
475
151
    def _serve_one_request_unguarded(self, protocol):
476
152
        while True:
477
 
            # We need to be careful not to read past the end of the current
478
 
            # request, or else the read from the pipe will block, so we use
479
 
            # protocol.next_read_size().
480
153
            bytes_to_read = protocol.next_read_size()
481
154
            if bytes_to_read == 0:
482
155
                # Finished serving this request.
483
156
                self._out.flush()
484
157
                return
485
 
            bytes = self.read_bytes(bytes_to_read)
486
 
            if bytes == b'':
 
158
            bytes = self._in.read(bytes_to_read)
 
159
            if bytes == '':
487
160
                # Connection has been closed.
488
161
                self.finished = True
489
162
                self._out.flush()
490
163
                return
491
164
            protocol.accept_bytes(bytes)
492
165
 
493
 
    def _disconnect_client(self):
494
 
        self._in.close()
495
 
        self._out.flush()
496
 
        self._out.close()
497
 
 
498
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
499
 
        """Wait for more bytes to be read, but timeout if none available.
500
 
 
501
 
        This allows us to detect idle connections, and stop trying to read from
502
 
        them, without setting the socket itself to non-blocking. This also
503
 
        allows us to specify when we watch for idle timeouts.
504
 
 
505
 
        :return: None, this will raise ConnectionTimeout if we time out before
506
 
            data is available.
507
 
        """
508
 
        if (getattr(self._in, 'fileno', None) is None
509
 
                or sys.platform == 'win32'):
510
 
            # You can't select() file descriptors on Windows.
511
 
            return
512
 
        try:
513
 
            return self._wait_on_descriptor(self._in, timeout_seconds)
514
 
        except io.UnsupportedOperation:
515
 
            return
516
 
 
517
 
    def _read_bytes(self, desired_count):
518
 
        return self._in.read(desired_count)
519
 
 
520
166
    def terminate_due_to_error(self):
521
167
        # TODO: This should log to a server log file, but no such thing
522
168
        # exists yet.  Andrew Bennetts 2006-09-29.
540
186
    request.finished_reading()
541
187
 
542
188
    It is up to the individual SmartClientMedium whether multiple concurrent
543
 
    requests can exist. See SmartClientMedium.get_request to obtain instances
544
 
    of SmartClientMediumRequest, and the concrete Medium you are using for
 
189
    requests can exist. See SmartClientMedium.get_request to obtain instances 
 
190
    of SmartClientMediumRequest, and the concrete Medium you are using for 
545
191
    details on concurrency and pipelining.
546
192
    """
547
193
 
556
202
    def accept_bytes(self, bytes):
557
203
        """Accept bytes for inclusion in this request.
558
204
 
559
 
        This method may not be called after finished_writing() has been
 
205
        This method may not be be called after finished_writing() has been
560
206
        called.  It depends upon the Medium whether or not the bytes will be
561
207
        immediately transmitted. Message based Mediums will tend to buffer the
562
208
        bytes until finished_writing() is called.
593
239
    def _finished_reading(self):
594
240
        """Helper for finished_reading.
595
241
 
596
 
        finished_reading checks the state of the request to determine if
 
242
        finished_reading checks the state of the request to determine if 
597
243
        finished_reading is allowed, and if it is hands off to _finished_reading
598
244
        to perform the action.
599
245
        """
613
259
    def _finished_writing(self):
614
260
        """Helper for finished_writing.
615
261
 
616
 
        finished_writing checks the state of the request to determine if
 
262
        finished_writing checks the state of the request to determine if 
617
263
        finished_writing is allowed, and if it is hands off to _finished_writing
618
264
        to perform the action.
619
265
        """
624
270
 
625
271
        This method will block and wait for count bytes to be read. It may not
626
272
        be invoked until finished_writing() has been called - this is to ensure
627
 
        a message-based approach to requests, for compatibility with message
 
273
        a message-based approach to requests, for compatability with message
628
274
        based mediums like HTTP.
629
275
        """
630
276
        if self._state == "writing":
634
280
        return self._read_bytes(count)
635
281
 
636
282
    def _read_bytes(self, count):
637
 
        """Helper for SmartClientMediumRequest.read_bytes.
 
283
        """Helper for read_bytes.
638
284
 
639
285
        read_bytes checks the state of the request to determing if bytes
640
286
        should be read. After that it hands off to _read_bytes to do the
641
287
        actual read.
642
 
 
643
 
        By default this forwards to self._medium.read_bytes because we are
644
 
        operating on the medium's stream.
645
 
        """
646
 
        return self._medium.read_bytes(count)
647
 
 
648
 
    def read_line(self):
649
 
        line = self._read_line()
650
 
        if not line.endswith(b'\n'):
651
 
            # end of file encountered reading from server
652
 
            raise errors.ConnectionReset(
653
 
                "Unexpected end of message. Please check connectivity "
654
 
                "and permissions, and report a bug if problems persist.")
655
 
        return line
656
 
 
657
 
    def _read_line(self):
658
 
        """Helper for SmartClientMediumRequest.read_line.
659
 
 
660
 
        By default this forwards to self._medium._get_line because we are
661
 
        operating on the medium's stream.
662
 
        """
663
 
        return self._medium._get_line()
664
 
 
665
 
 
666
 
class _VfsRefuser(object):
667
 
    """An object that refuses all VFS requests.
668
 
 
669
 
    """
670
 
 
671
 
    def __init__(self):
672
 
        client._SmartClient.hooks.install_named_hook(
673
 
            'call', self.check_vfs, 'vfs refuser')
674
 
 
675
 
    def check_vfs(self, params):
676
 
        try:
677
 
            request_method = request.request_handlers.get(params.method)
678
 
        except KeyError:
679
 
            # A method we don't know about doesn't count as a VFS method.
680
 
            return
681
 
        if issubclass(request_method, vfs.VfsRequest):
682
 
            raise HpssVfsRequestNotAllowed(params.method, params.args)
683
 
 
684
 
 
685
 
class _DebugCounter(object):
686
 
    """An object that counts the HPSS calls made to each client medium.
687
 
 
688
 
    When a medium is garbage-collected, or failing that when
689
 
    breezy.global_state exits, the total number of calls made on that medium
690
 
    are reported via trace.note.
691
 
    """
692
 
 
693
 
    def __init__(self):
694
 
        self.counts = weakref.WeakKeyDictionary()
695
 
        client._SmartClient.hooks.install_named_hook(
696
 
            'call', self.increment_call_count, 'hpss call counter')
697
 
        breezy.get_global_state().exit_stack.callback(self.flush_all)
698
 
 
699
 
    def track(self, medium):
700
 
        """Start tracking calls made to a medium.
701
 
 
702
 
        This only keeps a weakref to the medium, so shouldn't affect the
703
 
        medium's lifetime.
704
 
        """
705
 
        medium_repr = repr(medium)
706
 
        # Add this medium to the WeakKeyDictionary
707
 
        self.counts[medium] = dict(count=0, vfs_count=0,
708
 
                                   medium_repr=medium_repr)
709
 
        # Weakref callbacks are fired in reverse order of their association
710
 
        # with the referenced object.  So we add a weakref *after* adding to
711
 
        # the WeakKeyDict so that we can report the value from it before the
712
 
        # entry is removed by the WeakKeyDict's own callback.
713
 
        ref = weakref.ref(medium, self.done)
714
 
 
715
 
    def increment_call_count(self, params):
716
 
        # Increment the count in the WeakKeyDictionary
717
 
        value = self.counts[params.medium]
718
 
        value['count'] += 1
719
 
        try:
720
 
            request_method = request.request_handlers.get(params.method)
721
 
        except KeyError:
722
 
            # A method we don't know about doesn't count as a VFS method.
723
 
            return
724
 
        if issubclass(request_method, vfs.VfsRequest):
725
 
            value['vfs_count'] += 1
726
 
 
727
 
    def done(self, ref):
728
 
        value = self.counts[ref]
729
 
        count, vfs_count, medium_repr = (
730
 
            value['count'], value['vfs_count'], value['medium_repr'])
731
 
        # In case this callback is invoked for the same ref twice (by the
732
 
        # weakref callback and by the atexit function), set the call count back
733
 
        # to 0 so this item won't be reported twice.
734
 
        value['count'] = 0
735
 
        value['vfs_count'] = 0
736
 
        if count != 0:
737
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
738
 
                       count, vfs_count, medium_repr))
739
 
 
740
 
    def flush_all(self):
741
 
        for ref in list(self.counts.keys()):
742
 
            self.done(ref)
743
 
 
744
 
 
745
 
_debug_counter = None
746
 
_vfs_refuser = None
747
 
 
748
 
 
749
 
class SmartClientMedium(SmartMedium):
 
288
        """
 
289
        raise NotImplementedError(self._read_bytes)
 
290
 
 
291
 
 
292
class SmartClientMedium(object):
750
293
    """Smart client is a medium for sending smart protocol requests over."""
751
294
 
752
 
    def __init__(self, base):
753
 
        super(SmartClientMedium, self).__init__()
754
 
        self.base = base
755
 
        self._protocol_version_error = None
756
 
        self._protocol_version = None
757
 
        self._done_hello = False
758
 
        # Be optimistic: we assume the remote end can accept new remote
759
 
        # requests until we get an error saying otherwise.
760
 
        # _remote_version_is_before tracks the bzr version the remote side
761
 
        # can be based on what we've seen so far.
762
 
        self._remote_version_is_before = None
763
 
        # Install debug hook function if debug flag is set.
764
 
        if 'hpss' in debug.debug_flags:
765
 
            global _debug_counter
766
 
            if _debug_counter is None:
767
 
                _debug_counter = _DebugCounter()
768
 
            _debug_counter.track(self)
769
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
770
 
            global _vfs_refuser
771
 
            if _vfs_refuser is None:
772
 
                _vfs_refuser = _VfsRefuser()
773
 
 
774
 
    def _is_remote_before(self, version_tuple):
775
 
        """Is it possible the remote side supports RPCs for a given version?
776
 
 
777
 
        Typical use::
778
 
 
779
 
            needed_version = (1, 2)
780
 
            if medium._is_remote_before(needed_version):
781
 
                fallback_to_pre_1_2_rpc()
782
 
            else:
783
 
                try:
784
 
                    do_1_2_rpc()
785
 
                except UnknownSmartMethod:
786
 
                    medium._remember_remote_is_before(needed_version)
787
 
                    fallback_to_pre_1_2_rpc()
788
 
 
789
 
        :seealso: _remember_remote_is_before
790
 
        """
791
 
        if self._remote_version_is_before is None:
792
 
            # So far, the remote side seems to support everything
793
 
            return False
794
 
        return version_tuple >= self._remote_version_is_before
795
 
 
796
 
    def _remember_remote_is_before(self, version_tuple):
797
 
        """Tell this medium that the remote side is older the given version.
798
 
 
799
 
        :seealso: _is_remote_before
800
 
        """
801
 
        if (self._remote_version_is_before is not None and
802
 
                version_tuple > self._remote_version_is_before):
803
 
            # We have been told that the remote side is older than some version
804
 
            # which is newer than a previously supplied older-than version.
805
 
            # This indicates that some smart verb call is not guarded
806
 
            # appropriately (it should simply not have been tried).
807
 
            trace.mutter(
808
 
                "_remember_remote_is_before(%r) called, but "
809
 
                "_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before)
810
 
            if 'hpss' in debug.debug_flags:
811
 
                ui.ui_factory.show_warning(
812
 
                    "_remember_remote_is_before(%r) called, but "
813
 
                    "_remember_remote_is_before(%r) was called previously."
814
 
                    % (version_tuple, self._remote_version_is_before))
815
 
            return
816
 
        self._remote_version_is_before = version_tuple
817
 
 
818
 
    def protocol_version(self):
819
 
        """Find out if 'hello' smart request works."""
820
 
        if self._protocol_version_error is not None:
821
 
            raise self._protocol_version_error
822
 
        if not self._done_hello:
823
 
            try:
824
 
                medium_request = self.get_request()
825
 
                # Send a 'hello' request in protocol version one, for maximum
826
 
                # backwards compatibility.
827
 
                client_protocol = protocol.SmartClientRequestProtocolOne(
828
 
                    medium_request)
829
 
                client_protocol.query_version()
830
 
                self._done_hello = True
831
 
            except errors.SmartProtocolError as e:
832
 
                # Cache the error, just like we would cache a successful
833
 
                # result.
834
 
                self._protocol_version_error = e
835
 
                raise
836
 
        return '2'
837
 
 
838
 
    def should_probe(self):
839
 
        """Should RemoteBzrDirFormat.probe_transport send a smart request on
840
 
        this medium?
841
 
 
842
 
        Some transports are unambiguously smart-only; there's no need to check
843
 
        if the transport is able to carry smart requests, because that's all
844
 
        it is for.  In those cases, this method should return False.
845
 
 
846
 
        But some HTTP transports can sometimes fail to carry smart requests,
847
 
        but still be usuable for accessing remote bzrdirs via plain file
848
 
        accesses.  So for those transports, their media should return True here
849
 
        so that RemoteBzrDirFormat can determine if it is appropriate for that
850
 
        transport.
851
 
        """
852
 
        return False
853
 
 
854
295
    def disconnect(self):
855
296
        """If this medium maintains a persistent connection, close it.
856
 
 
 
297
        
857
298
        The default implementation does nothing.
858
299
        """
859
 
 
860
 
    def remote_path_from_transport(self, transport):
861
 
        """Convert transport into a path suitable for using in a request.
862
 
 
863
 
        Note that the resulting remote path doesn't encode the host name or
864
 
        anything but path, so it is only safe to use it in requests sent over
865
 
        the medium from the matching transport.
866
 
        """
867
 
        medium_base = urlutils.join(self.base, '/')
868
 
        rel_url = urlutils.relative_url(medium_base, transport.base)
869
 
        return urlutils.unquote(rel_url)
870
 
 
 
300
        
871
301
 
872
302
class SmartClientStreamMedium(SmartClientMedium):
873
303
    """Stream based medium common class.
878
308
    receive bytes.
879
309
    """
880
310
 
881
 
    def __init__(self, base):
882
 
        SmartClientMedium.__init__(self, base)
 
311
    def __init__(self):
883
312
        self._current_request = None
884
313
 
885
314
    def accept_bytes(self, bytes):
893
322
 
894
323
    def _flush(self):
895
324
        """Flush the output stream.
896
 
 
 
325
        
897
326
        This method is used by the SmartClientStreamMediumRequest to ensure that
898
327
        all data for a request is sent, to avoid long timeouts or deadlocks.
899
328
        """
907
336
        """
908
337
        return SmartClientStreamMediumRequest(self)
909
338
 
910
 
    def reset(self):
911
 
        """We have been disconnected, reset current state.
912
 
 
913
 
        This resets things like _current_request and connected state.
914
 
        """
915
 
        self.disconnect()
916
 
        self._current_request = None
 
339
    def read_bytes(self, count):
 
340
        return self._read_bytes(count)
917
341
 
918
342
 
919
343
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
920
344
    """A client medium using simple pipes.
921
 
 
 
345
    
922
346
    This client does not manage the pipes: it assumes they will always be open.
923
347
    """
924
348
 
925
 
    def __init__(self, readable_pipe, writeable_pipe, base):
926
 
        SmartClientStreamMedium.__init__(self, base)
 
349
    def __init__(self, readable_pipe, writeable_pipe):
 
350
        SmartClientStreamMedium.__init__(self)
927
351
        self._readable_pipe = readable_pipe
928
352
        self._writeable_pipe = writeable_pipe
929
353
 
930
 
    def _accept_bytes(self, data):
 
354
    def _accept_bytes(self, bytes):
931
355
        """See SmartClientStreamMedium.accept_bytes."""
932
 
        try:
933
 
            self._writeable_pipe.write(data)
934
 
        except IOError as e:
935
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
936
 
                raise errors.ConnectionReset(
937
 
                    "Error trying to write to subprocess", e)
938
 
            raise
939
 
        self._report_activity(len(data), 'write')
 
356
        self._writeable_pipe.write(bytes)
940
357
 
941
358
    def _flush(self):
942
359
        """See SmartClientStreamMedium._flush()."""
943
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
944
 
        #       However, testing shows that even when the child process is
945
 
        #       gone, this doesn't error.
946
360
        self._writeable_pipe.flush()
947
361
 
948
362
    def _read_bytes(self, count):
949
363
        """See SmartClientStreamMedium._read_bytes."""
950
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
951
 
        data = self._readable_pipe.read(bytes_to_read)
952
 
        self._report_activity(len(data), 'read')
953
 
        return data
954
 
 
955
 
 
956
 
class SSHParams(object):
957
 
    """A set of parameters for starting a remote bzr via SSH."""
958
 
 
 
364
        return self._readable_pipe.read(count)
 
365
 
 
366
 
 
367
class SmartSSHClientMedium(SmartClientStreamMedium):
 
368
    """A client medium using SSH."""
 
369
    
959
370
    def __init__(self, host, port=None, username=None, password=None,
960
 
                 bzr_remote_path='bzr'):
961
 
        self.host = host
962
 
        self.port = port
963
 
        self.username = username
964
 
        self.password = password
965
 
        self.bzr_remote_path = bzr_remote_path
966
 
 
967
 
 
968
 
class SmartSSHClientMedium(SmartClientStreamMedium):
969
 
    """A client medium using SSH.
970
 
 
971
 
    It delegates IO to a SmartSimplePipesClientMedium or
972
 
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
973
 
    """
974
 
 
975
 
    def __init__(self, base, ssh_params, vendor=None):
 
371
            vendor=None):
976
372
        """Creates a client that will connect on the first use.
977
 
 
978
 
        :param ssh_params: A SSHParams instance.
 
373
        
979
374
        :param vendor: An optional override for the ssh vendor to use. See
980
 
            breezy.transport.ssh for details on ssh vendors.
 
375
            bzrlib.transport.ssh for details on ssh vendors.
981
376
        """
982
 
        self._real_medium = None
983
 
        self._ssh_params = ssh_params
984
 
        # for the benefit of progress making a short description of this
985
 
        # transport
986
 
        self._scheme = 'bzr+ssh'
987
 
        # SmartClientStreamMedium stores the repr of this object in its
988
 
        # _DebugCounter so we have to store all the values used in our repr
989
 
        # method before calling the super init.
990
 
        SmartClientStreamMedium.__init__(self, base)
 
377
        SmartClientStreamMedium.__init__(self)
 
378
        self._connected = False
 
379
        self._host = host
 
380
        self._password = password
 
381
        self._port = port
 
382
        self._username = username
 
383
        self._read_from = None
 
384
        self._ssh_connection = None
991
385
        self._vendor = vendor
992
 
        self._ssh_connection = None
993
 
 
994
 
    def __repr__(self):
995
 
        if self._ssh_params.port is None:
996
 
            maybe_port = ''
997
 
        else:
998
 
            maybe_port = ':%s' % self._ssh_params.port
999
 
        if self._ssh_params.username is None:
1000
 
            maybe_user = ''
1001
 
        else:
1002
 
            maybe_user = '%s@' % self._ssh_params.username
1003
 
        return "%s(%s://%s%s%s/)" % (
1004
 
            self.__class__.__name__,
1005
 
            self._scheme,
1006
 
            maybe_user,
1007
 
            self._ssh_params.host,
1008
 
            maybe_port)
 
386
        self._write_to = None
1009
387
 
1010
388
    def _accept_bytes(self, bytes):
1011
389
        """See SmartClientStreamMedium.accept_bytes."""
1012
390
        self._ensure_connection()
1013
 
        self._real_medium.accept_bytes(bytes)
 
391
        self._write_to.write(bytes)
1014
392
 
1015
393
    def disconnect(self):
1016
394
        """See SmartClientMedium.disconnect()."""
1017
 
        if self._real_medium is not None:
1018
 
            self._real_medium.disconnect()
1019
 
            self._real_medium = None
1020
 
        if self._ssh_connection is not None:
1021
 
            self._ssh_connection.close()
1022
 
            self._ssh_connection = None
 
395
        if not self._connected:
 
396
            return
 
397
        self._read_from.close()
 
398
        self._write_to.close()
 
399
        self._ssh_connection.close()
 
400
        self._connected = False
1023
401
 
1024
402
    def _ensure_connection(self):
1025
403
        """Connect this medium if not already connected."""
1026
 
        if self._real_medium is not None:
 
404
        if self._connected:
1027
405
            return
 
406
        executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
1028
407
        if self._vendor is None:
1029
408
            vendor = ssh._get_ssh_vendor()
1030
409
        else:
1031
410
            vendor = self._vendor
1032
 
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1033
 
                                                  self._ssh_params.password, self._ssh_params.host,
1034
 
                                                  self._ssh_params.port,
1035
 
                                                  command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
1036
 
                                                           '--directory=/', '--allow-writes'])
1037
 
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1038
 
        if io_kind == 'socket':
1039
 
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1040
 
                self.base, io_object)
1041
 
        elif io_kind == 'pipes':
1042
 
            read_from, write_to = io_object
1043
 
            self._real_medium = SmartSimplePipesClientMedium(
1044
 
                read_from, write_to, self.base)
1045
 
        else:
1046
 
            raise AssertionError(
1047
 
                "Unexpected io_kind %r from %r"
1048
 
                % (io_kind, self._ssh_connection))
1049
 
        for hook in transport.Transport.hooks["post_connect"]:
1050
 
            hook(self)
 
411
        self._ssh_connection = vendor.connect_ssh(self._username,
 
412
                self._password, self._host, self._port,
 
413
                command=[executable, 'serve', '--inet', '--directory=/',
 
414
                         '--allow-writes'])
 
415
        self._read_from, self._write_to = \
 
416
            self._ssh_connection.get_filelike_channels()
 
417
        self._connected = True
1051
418
 
1052
419
    def _flush(self):
1053
420
        """See SmartClientStreamMedium._flush()."""
1054
 
        self._real_medium._flush()
 
421
        self._write_to.flush()
1055
422
 
1056
423
    def _read_bytes(self, count):
1057
424
        """See SmartClientStreamMedium.read_bytes."""
1058
 
        if self._real_medium is None:
 
425
        if not self._connected:
1059
426
            raise errors.MediumNotConnected(self)
1060
 
        return self._real_medium.read_bytes(count)
1061
 
 
1062
 
 
1063
 
# Port 4155 is the default port for bzr://, registered with IANA.
1064
 
BZR_DEFAULT_INTERFACE = None
1065
 
BZR_DEFAULT_PORT = 4155
1066
 
 
1067
 
 
1068
 
class SmartClientSocketMedium(SmartClientStreamMedium):
1069
 
    """A client medium using a socket.
1070
 
 
1071
 
    This class isn't usable directly.  Use one of its subclasses instead.
1072
 
    """
1073
 
 
1074
 
    def __init__(self, base):
1075
 
        SmartClientStreamMedium.__init__(self, base)
 
427
        return self._read_from.read(count)
 
428
 
 
429
 
 
430
class SmartTCPClientMedium(SmartClientStreamMedium):
 
431
    """A client medium using TCP."""
 
432
    
 
433
    def __init__(self, host, port):
 
434
        """Creates a client that will connect on the first use."""
 
435
        SmartClientStreamMedium.__init__(self)
 
436
        self._connected = False
 
437
        self._host = host
 
438
        self._port = port
1076
439
        self._socket = None
1077
 
        self._connected = False
1078
440
 
1079
441
    def _accept_bytes(self, bytes):
1080
442
        """See SmartClientMedium.accept_bytes."""
1081
443
        self._ensure_connection()
1082
 
        osutils.send_all(self._socket, bytes, self._report_activity)
 
444
        self._socket.sendall(bytes)
 
445
 
 
446
    def disconnect(self):
 
447
        """See SmartClientMedium.disconnect()."""
 
448
        if not self._connected:
 
449
            return
 
450
        self._socket.close()
 
451
        self._socket = None
 
452
        self._connected = False
1083
453
 
1084
454
    def _ensure_connection(self):
1085
455
        """Connect this medium if not already connected."""
1086
 
        raise NotImplementedError(self._ensure_connection)
 
456
        if self._connected:
 
457
            return
 
458
        self._socket = socket.socket()
 
459
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
460
        result = self._socket.connect_ex((self._host, int(self._port)))
 
461
        if result:
 
462
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
 
463
                    (self._host, self._port, os.strerror(result)))
 
464
        self._connected = True
1087
465
 
1088
466
    def _flush(self):
1089
467
        """See SmartClientStreamMedium._flush().
1090
 
 
1091
 
        For sockets we do no flushing. For TCP sockets we may want to turn off
1092
 
        TCP_NODELAY and add a means to do a flush, but that can be done in the
1093
 
        future.
 
468
        
 
469
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
 
470
        add a means to do a flush, but that can be done in the future.
1094
471
        """
1095
472
 
1096
473
    def _read_bytes(self, count):
1097
474
        """See SmartClientMedium.read_bytes."""
1098
475
        if not self._connected:
1099
476
            raise errors.MediumNotConnected(self)
1100
 
        return osutils.read_bytes_from_socket(
1101
 
            self._socket, self._report_activity)
1102
 
 
1103
 
    def disconnect(self):
1104
 
        """See SmartClientMedium.disconnect()."""
1105
 
        if not self._connected:
1106
 
            return
1107
 
        self._socket.close()
1108
 
        self._socket = None
1109
 
        self._connected = False
1110
 
 
1111
 
 
1112
 
class SmartTCPClientMedium(SmartClientSocketMedium):
1113
 
    """A client medium that creates a TCP connection."""
1114
 
 
1115
 
    def __init__(self, host, port, base):
1116
 
        """Creates a client that will connect on the first use."""
1117
 
        SmartClientSocketMedium.__init__(self, base)
1118
 
        self._host = host
1119
 
        self._port = port
1120
 
 
1121
 
    def _ensure_connection(self):
1122
 
        """Connect this medium if not already connected."""
1123
 
        if self._connected:
1124
 
            return
1125
 
        if self._port is None:
1126
 
            port = BZR_DEFAULT_PORT
1127
 
        else:
1128
 
            port = int(self._port)
1129
 
        try:
1130
 
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1131
 
                                           socket.SOCK_STREAM, 0, 0)
1132
 
        except socket.gaierror as xxx_todo_changeme:
1133
 
            (err_num, err_msg) = xxx_todo_changeme.args
1134
 
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1135
 
                                         (self._host, port, err_msg))
1136
 
        # Initialize err in case there are no addresses returned:
1137
 
        last_err = socket.error("no address found for %s" % self._host)
1138
 
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1139
 
            try:
1140
 
                self._socket = socket.socket(family, socktype, proto)
1141
 
                self._socket.setsockopt(socket.IPPROTO_TCP,
1142
 
                                        socket.TCP_NODELAY, 1)
1143
 
                self._socket.connect(sockaddr)
1144
 
            except socket.error as err:
1145
 
                if self._socket is not None:
1146
 
                    self._socket.close()
1147
 
                self._socket = None
1148
 
                last_err = err
1149
 
                continue
1150
 
            break
1151
 
        if self._socket is None:
1152
 
            # socket errors either have a (string) or (errno, string) as their
1153
 
            # args.
1154
 
            if isinstance(last_err.args, str):
1155
 
                err_msg = last_err.args
1156
 
            else:
1157
 
                err_msg = last_err.args[1]
1158
 
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1159
 
                                         (self._host, port, err_msg))
1160
 
        self._connected = True
1161
 
        for hook in transport.Transport.hooks["post_connect"]:
1162
 
            hook(self)
1163
 
 
1164
 
 
1165
 
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1166
 
    """A client medium for an already connected socket.
1167
 
 
1168
 
    Note that this class will assume it "owns" the socket, so it will close it
1169
 
    when its disconnect method is called.
1170
 
    """
1171
 
 
1172
 
    def __init__(self, base, sock):
1173
 
        SmartClientSocketMedium.__init__(self, base)
1174
 
        self._socket = sock
1175
 
        self._connected = True
1176
 
 
1177
 
    def _ensure_connection(self):
1178
 
        # Already connected, by definition!  So nothing to do.
1179
 
        pass
 
477
        return self._socket.recv(count)
1180
478
 
1181
479
 
1182
480
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1195
493
 
1196
494
    def _accept_bytes(self, bytes):
1197
495
        """See SmartClientMediumRequest._accept_bytes.
1198
 
 
 
496
        
1199
497
        This forwards to self._medium._accept_bytes because we are operating
1200
498
        on the mediums stream.
1201
499
        """
1204
502
    def _finished_reading(self):
1205
503
        """See SmartClientMediumRequest._finished_reading.
1206
504
 
1207
 
        This clears the _current_request on self._medium to allow a new
 
505
        This clears the _current_request on self._medium to allow a new 
1208
506
        request to be created.
1209
507
        """
1210
 
        if self._medium._current_request is not self:
1211
 
            raise AssertionError()
 
508
        assert self._medium._current_request is self
1212
509
        self._medium._current_request = None
1213
 
 
 
510
        
1214
511
    def _finished_writing(self):
1215
512
        """See SmartClientMediumRequest._finished_writing.
1216
513
 
1217
514
        This invokes self._medium._flush to ensure all bytes are transmitted.
1218
515
        """
1219
516
        self._medium._flush()
 
517
 
 
518
    def _read_bytes(self, count):
 
519
        """See SmartClientMediumRequest._read_bytes.
 
520
        
 
521
        This forwards to self._medium._read_bytes because we are operating
 
522
        on the mediums stream.
 
523
        """
 
524
        return self._medium._read_bytes(count)
 
525
 
 
526