/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Robert Collins
  • Date: 2010-05-06 23:41:35 UTC
  • mto: This revision was merged to the branch mainline in revision 5223.
  • Revision ID: robertc@robertcollins.net-20100506234135-yivbzczw1sejxnxc
Lock methods on ``Tree``, ``Branch`` and ``Repository`` are now
expected to return an object which can be used to unlock them. This reduces
duplicate code when using cleanups. The previous 'tokens's returned by
``Branch.lock_write`` and ``Repository.lock_write`` are now attributes
on the result of the lock_write. ``repository.RepositoryWriteLockResult``
and ``branch.BranchWriteLockResult`` document this. (Robert Collins)

``log._get_info_for_log_files`` now takes an add_cleanup callable.
(Robert Collins)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
breezy/transport/smart/__init__.py.
 
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
 
from __future__ import absolute_import
28
 
 
29
 
import errno
30
 
import io
31
27
import os
32
28
import sys
33
 
import time
34
 
 
35
 
try:
36
 
    import _thread
37
 
except ImportError:
38
 
    import thread as _thread
39
 
 
40
 
import breezy
41
 
from ...lazy_import import lazy_import
 
29
import urllib
 
30
 
 
31
from bzrlib.lazy_import import lazy_import
42
32
lazy_import(globals(), """
43
 
import select
 
33
import atexit
44
34
import socket
 
35
import thread
45
36
import weakref
46
37
 
47
 
from breezy import (
 
38
from bzrlib import (
48
39
    debug,
 
40
    errors,
 
41
    symbol_versioning,
49
42
    trace,
50
 
    transport,
51
43
    ui,
52
44
    urlutils,
53
45
    )
54
 
from breezy.i18n import gettext
55
 
from breezy.bzr.smart import client, protocol, request, signals, vfs
56
 
from breezy.transport import ssh
 
46
from bzrlib.smart import client, protocol, request, vfs
 
47
from bzrlib.transport import ssh
57
48
""")
58
 
from ... import (
59
 
    errors,
60
 
    osutils,
61
 
    )
 
49
from bzrlib import osutils
62
50
 
63
51
# Throughout this module buffer size parameters are either limited to be at
64
52
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
66
54
# from non-sockets as well.
67
55
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
68
56
 
69
 
 
70
 
class HpssVfsRequestNotAllowed(errors.BzrError):
71
 
 
72
 
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
73
 
            "%(method)s, %(arguments)s.")
74
 
 
75
 
    def __init__(self, method, arguments):
76
 
        self.method = method
77
 
        self.arguments = arguments
78
 
 
79
 
 
80
57
def _get_protocol_factory_for_bytes(bytes):
81
58
    """Determine the right protocol factory for 'bytes'.
82
59
 
118
95
    :returns: a tuple of two strs: (line, excess)
119
96
    """
120
97
    newline_pos = -1
121
 
    bytes = b''
 
98
    bytes = ''
122
99
    while newline_pos == -1:
123
100
        new_bytes = read_bytes_func(1)
124
101
        bytes += new_bytes
125
 
        if new_bytes == b'':
 
102
        if new_bytes == '':
126
103
            # Ran out of bytes before receiving a complete line.
127
 
            return bytes, b''
128
 
        newline_pos = bytes.find(b'\n')
129
 
    line = bytes[:newline_pos + 1]
130
 
    excess = bytes[newline_pos + 1:]
 
104
            return bytes, ''
 
105
        newline_pos = bytes.find('\n')
 
106
    line = bytes[:newline_pos+1]
 
107
    excess = bytes[newline_pos+1:]
131
108
    return line, excess
132
109
 
133
110
 
137
114
    def __init__(self):
138
115
        self._push_back_buffer = None
139
116
 
140
 
    def _push_back(self, data):
 
117
    def _push_back(self, bytes):
141
118
        """Return unused bytes to the medium, because they belong to the next
142
119
        request(s).
143
120
 
144
121
        This sets the _push_back_buffer to the given bytes.
145
122
        """
146
 
        if not isinstance(data, bytes):
147
 
            raise TypeError(data)
148
123
        if self._push_back_buffer is not None:
149
124
            raise AssertionError(
150
125
                "_push_back called when self._push_back_buffer is %r"
151
126
                % (self._push_back_buffer,))
152
 
        if data == b'':
 
127
        if bytes == '':
153
128
            return
154
 
        self._push_back_buffer = data
 
129
        self._push_back_buffer = bytes
155
130
 
156
131
    def _get_push_back_buffer(self):
157
 
        if self._push_back_buffer == b'':
 
132
        if self._push_back_buffer == '':
158
133
            raise AssertionError(
159
134
                '%s._push_back_buffer should never be the empty string, '
160
135
                'which can be confused with EOF' % (self,))
201
176
        ui.ui_factory.report_transport_activity(self, bytes, direction)
202
177
 
203
178
 
204
 
_bad_file_descriptor = (errno.EBADF,)
205
 
if sys.platform == 'win32':
206
 
    # Given on Windows if you pass a closed socket to select.select. Probably
207
 
    # also given if you pass a file handle to select.
208
 
    WSAENOTSOCK = 10038
209
 
    _bad_file_descriptor += (WSAENOTSOCK,)
210
 
 
211
 
 
212
179
class SmartServerStreamMedium(SmartMedium):
213
180
    """Handles smart commands coming over a stream.
214
181
 
227
194
        the stream.  See also the _push_back method.
228
195
    """
229
196
 
230
 
    _timer = time.time
231
 
 
232
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
197
    def __init__(self, backing_transport, root_client_path='/'):
233
198
        """Construct new server.
234
199
 
235
200
        :param backing_transport: Transport for the directory served.
238
203
        self.backing_transport = backing_transport
239
204
        self.root_client_path = root_client_path
240
205
        self.finished = False
241
 
        if timeout is None:
242
 
            raise AssertionError('You must supply a timeout.')
243
 
        self._client_timeout = timeout
244
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
245
206
        SmartMedium.__init__(self)
246
207
 
247
208
    def serve(self):
253
214
            while not self.finished:
254
215
                server_protocol = self._build_protocol()
255
216
                self._serve_one_request(server_protocol)
256
 
        except errors.ConnectionTimeout as e:
257
 
            trace.note('%s' % (e,))
258
 
            trace.log_exception_quietly()
259
 
            self._disconnect_client()
260
 
            # We reported it, no reason to make a big fuss.
261
 
            return
262
 
        except Exception as e:
 
217
        except Exception, e:
263
218
            stderr.write("%s terminating on exception %s\n" % (self, e))
264
219
            raise
265
 
        self._disconnect_client()
266
 
 
267
 
    def _stop_gracefully(self):
268
 
        """When we finish this message, stop looking for more."""
269
 
        trace.mutter('Stopping %s' % (self,))
270
 
        self.finished = True
271
 
 
272
 
    def _disconnect_client(self):
273
 
        """Close the current connection. We stopped due to a timeout/etc."""
274
 
        # The default implementation is a no-op, because that is all we used to
275
 
        # do when disconnecting from a client. I suppose we never had the
276
 
        # *server* initiate a disconnect, before
277
 
 
278
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
279
 
        """Wait for more bytes to be read, but timeout if none available.
280
 
 
281
 
        This allows us to detect idle connections, and stop trying to read from
282
 
        them, without setting the socket itself to non-blocking. This also
283
 
        allows us to specify when we watch for idle timeouts.
284
 
 
285
 
        :return: Did we timeout? (True if we timed out, False if there is data
286
 
            to be read)
287
 
        """
288
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
289
220
 
290
221
    def _build_protocol(self):
291
222
        """Identifies the version of the incoming request, and returns an
296
227
 
297
228
        :returns: a SmartServerRequestProtocol.
298
229
        """
299
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
300
 
        if self.finished:
301
 
            # We're stopping, so don't try to do any more work
302
 
            return None
303
230
        bytes = self._get_line()
304
231
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
305
232
        protocol = protocol_factory(
307
234
        protocol.accept_bytes(unused_bytes)
308
235
        return protocol
309
236
 
310
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
311
 
        """select() on a file descriptor, waiting for nonblocking read()
312
 
 
313
 
        This will raise a ConnectionTimeout exception if we do not get a
314
 
        readable handle before timeout_seconds.
315
 
        :return: None
316
 
        """
317
 
        t_end = self._timer() + timeout_seconds
318
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
319
 
        rs = xs = None
320
 
        while not rs and not xs and self._timer() < t_end:
321
 
            if self.finished:
322
 
                return
323
 
            try:
324
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
325
 
            except (select.error, socket.error) as e:
326
 
                err = getattr(e, 'errno', None)
327
 
                if err is None and getattr(e, 'args', None) is not None:
328
 
                    # select.error doesn't have 'errno', it just has args[0]
329
 
                    err = e.args[0]
330
 
                if err in _bad_file_descriptor:
331
 
                    return  # Not a socket indicates read() will fail
332
 
                elif err == errno.EINTR:
333
 
                    # Interrupted, keep looping.
334
 
                    continue
335
 
                raise
336
 
            except ValueError:
337
 
                return  # Socket may already be closed
338
 
        if rs or xs:
339
 
            return
340
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
341
 
                                       % (timeout_seconds,))
342
 
 
343
237
    def _serve_one_request(self, protocol):
344
238
        """Read one request from input, process, send back a response.
345
239
 
346
240
        :param protocol: a SmartServerRequestProtocol.
347
241
        """
348
 
        if protocol is None:
349
 
            return
350
242
        try:
351
243
            self._serve_one_request_unguarded(protocol)
352
244
        except KeyboardInterrupt:
353
245
            raise
354
 
        except Exception as e:
 
246
        except Exception, e:
355
247
            self.terminate_due_to_error()
356
248
 
357
249
    def terminate_due_to_error(self):
368
260
 
369
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
370
262
 
371
 
    def __init__(self, sock, backing_transport, root_client_path='/',
372
 
                 timeout=None):
 
263
    def __init__(self, sock, backing_transport, root_client_path='/'):
373
264
        """Constructor.
374
265
 
375
266
        :param sock: the socket the server will read from.  It will be put
376
267
            into blocking mode.
377
268
        """
378
269
        SmartServerStreamMedium.__init__(
379
 
            self, backing_transport, root_client_path=root_client_path,
380
 
            timeout=timeout)
 
270
            self, backing_transport, root_client_path=root_client_path)
381
271
        sock.setblocking(True)
382
272
        self.socket = sock
383
 
        # Get the getpeername now, as we might be closed later when we care.
384
 
        try:
385
 
            self._client_info = sock.getpeername()
386
 
        except socket.error:
387
 
            self._client_info = '<unknown>'
388
 
 
389
 
    def __str__(self):
390
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
391
 
 
392
 
    def __repr__(self):
393
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
394
 
                                     self._client_info)
395
273
 
396
274
    def _serve_one_request_unguarded(self, protocol):
397
275
        while protocol.next_read_size():
399
277
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
400
278
            # short read immediately rather than block.
401
279
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
402
 
            if bytes == b'':
 
280
            if bytes == '':
403
281
                self.finished = True
404
282
                return
405
283
            protocol.accept_bytes(bytes)
406
284
 
407
285
        self._push_back(protocol.unused_data)
408
286
 
409
 
    def _disconnect_client(self):
410
 
        """Close the current connection. We stopped due to a timeout/etc."""
411
 
        self.socket.close()
412
 
 
413
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
414
 
        """Wait for more bytes to be read, but timeout if none available.
415
 
 
416
 
        This allows us to detect idle connections, and stop trying to read from
417
 
        them, without setting the socket itself to non-blocking. This also
418
 
        allows us to specify when we watch for idle timeouts.
419
 
 
420
 
        :return: None, this will raise ConnectionTimeout if we time out before
421
 
            data is available.
422
 
        """
423
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
424
 
 
425
287
    def _read_bytes(self, desired_count):
426
288
        return osutils.read_bytes_from_socket(
427
289
            self.socket, self._report_activity)
433
295
        self.finished = True
434
296
 
435
297
    def _write_out(self, bytes):
436
 
        tstart = osutils.perf_counter()
 
298
        tstart = osutils.timer_func()
437
299
        osutils.send_all(self.socket, bytes, self._report_activity)
438
300
        if 'hpss' in debug.debug_flags:
439
 
            thread_id = _thread.get_ident()
 
301
            thread_id = thread.get_ident()
440
302
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
441
303
                         % ('wrote', thread_id, len(bytes),
442
 
                            osutils.perf_counter() - tstart))
 
304
                            osutils.timer_func() - tstart))
443
305
 
444
306
 
445
307
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
446
308
 
447
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
309
    def __init__(self, in_file, out_file, backing_transport):
448
310
        """Construct new server.
449
311
 
450
312
        :param in_file: Python file from which requests can be read.
451
313
        :param out_file: Python file to write responses.
452
314
        :param backing_transport: Transport for the directory served.
453
315
        """
454
 
        SmartServerStreamMedium.__init__(self, backing_transport,
455
 
                                         timeout=timeout)
 
316
        SmartServerStreamMedium.__init__(self, backing_transport)
456
317
        if sys.platform == 'win32':
457
318
            # force binary mode for files
458
319
            import msvcrt
463
324
        self._in = in_file
464
325
        self._out = out_file
465
326
 
466
 
    def serve(self):
467
 
        """See SmartServerStreamMedium.serve"""
468
 
        # This is the regular serve, except it adds signal trapping for soft
469
 
        # shutdown.
470
 
        stop_gracefully = self._stop_gracefully
471
 
        signals.register_on_hangup(id(self), stop_gracefully)
472
 
        try:
473
 
            return super(SmartServerPipeStreamMedium, self).serve()
474
 
        finally:
475
 
            signals.unregister_on_hangup(id(self))
476
 
 
477
327
    def _serve_one_request_unguarded(self, protocol):
478
328
        while True:
479
329
            # We need to be careful not to read past the end of the current
485
335
                self._out.flush()
486
336
                return
487
337
            bytes = self.read_bytes(bytes_to_read)
488
 
            if bytes == b'':
 
338
            if bytes == '':
489
339
                # Connection has been closed.
490
340
                self.finished = True
491
341
                self._out.flush()
492
342
                return
493
343
            protocol.accept_bytes(bytes)
494
344
 
495
 
    def _disconnect_client(self):
496
 
        self._in.close()
497
 
        self._out.flush()
498
 
        self._out.close()
499
 
 
500
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
501
 
        """Wait for more bytes to be read, but timeout if none available.
502
 
 
503
 
        This allows us to detect idle connections, and stop trying to read from
504
 
        them, without setting the socket itself to non-blocking. This also
505
 
        allows us to specify when we watch for idle timeouts.
506
 
 
507
 
        :return: None, this will raise ConnectionTimeout if we time out before
508
 
            data is available.
509
 
        """
510
 
        if (getattr(self._in, 'fileno', None) is None
511
 
                or sys.platform == 'win32'):
512
 
            # You can't select() file descriptors on Windows.
513
 
            return
514
 
        try:
515
 
            return self._wait_on_descriptor(self._in, timeout_seconds)
516
 
        except io.UnsupportedOperation:
517
 
            return
518
 
 
519
345
    def _read_bytes(self, desired_count):
520
346
        return self._in.read(desired_count)
521
347
 
649
475
 
650
476
    def read_line(self):
651
477
        line = self._read_line()
652
 
        if not line.endswith(b'\n'):
 
478
        if not line.endswith('\n'):
653
479
            # end of file encountered reading from server
654
480
            raise errors.ConnectionReset(
655
481
                "Unexpected end of message. Please check connectivity "
665
491
        return self._medium._get_line()
666
492
 
667
493
 
668
 
class _VfsRefuser(object):
669
 
    """An object that refuses all VFS requests.
670
 
 
671
 
    """
672
 
 
673
 
    def __init__(self):
674
 
        client._SmartClient.hooks.install_named_hook(
675
 
            'call', self.check_vfs, 'vfs refuser')
676
 
 
677
 
    def check_vfs(self, params):
678
 
        try:
679
 
            request_method = request.request_handlers.get(params.method)
680
 
        except KeyError:
681
 
            # A method we don't know about doesn't count as a VFS method.
682
 
            return
683
 
        if issubclass(request_method, vfs.VfsRequest):
684
 
            raise HpssVfsRequestNotAllowed(params.method, params.args)
685
 
 
686
 
 
687
494
class _DebugCounter(object):
688
495
    """An object that counts the HPSS calls made to each client medium.
689
496
 
690
 
    When a medium is garbage-collected, or failing that when
691
 
    breezy.global_state exits, the total number of calls made on that medium
692
 
    are reported via trace.note.
 
497
    When a medium is garbage-collected, or failing that when atexit functions
 
498
    are run, the total number of calls made on that medium are reported via
 
499
    trace.note.
693
500
    """
694
501
 
695
502
    def __init__(self):
696
503
        self.counts = weakref.WeakKeyDictionary()
697
504
        client._SmartClient.hooks.install_named_hook(
698
505
            'call', self.increment_call_count, 'hpss call counter')
699
 
        breezy.get_global_state().exit_stack.callback(self.flush_all)
 
506
        atexit.register(self.flush_all)
700
507
 
701
508
    def track(self, medium):
702
509
        """Start tracking calls made to a medium.
736
543
        value['count'] = 0
737
544
        value['vfs_count'] = 0
738
545
        if count != 0:
739
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
740
 
                       count, vfs_count, medium_repr))
 
546
            trace.note('HPSS calls: %d (%d vfs) %s',
 
547
                       count, vfs_count, medium_repr)
741
548
 
742
549
    def flush_all(self):
743
550
        for ref in list(self.counts.keys()):
744
551
            self.done(ref)
745
552
 
746
 
 
747
553
_debug_counter = None
748
 
_vfs_refuser = None
749
554
 
750
555
 
751
556
class SmartClientMedium(SmartMedium):
768
573
            if _debug_counter is None:
769
574
                _debug_counter = _DebugCounter()
770
575
            _debug_counter.track(self)
771
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
772
 
            global _vfs_refuser
773
 
            if _vfs_refuser is None:
774
 
                _vfs_refuser = _VfsRefuser()
775
576
 
776
577
    def _is_remote_before(self, version_tuple):
777
578
        """Is it possible the remote side supports RPCs for a given version?
801
602
        :seealso: _is_remote_before
802
603
        """
803
604
        if (self._remote_version_is_before is not None and
804
 
                version_tuple > self._remote_version_is_before):
 
605
            version_tuple > self._remote_version_is_before):
805
606
            # We have been told that the remote side is older than some version
806
607
            # which is newer than a previously supplied older-than version.
807
608
            # This indicates that some smart verb call is not guarded
808
609
            # appropriately (it should simply not have been tried).
809
610
            trace.mutter(
810
611
                "_remember_remote_is_before(%r) called, but "
811
 
                "_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before)
 
612
                "_remember_remote_is_before(%r) was called previously."
 
613
                , version_tuple, self._remote_version_is_before)
812
614
            if 'hpss' in debug.debug_flags:
813
615
                ui.ui_factory.show_warning(
814
616
                    "_remember_remote_is_before(%r) called, but "
826
628
                medium_request = self.get_request()
827
629
                # Send a 'hello' request in protocol version one, for maximum
828
630
                # backwards compatibility.
829
 
                client_protocol = protocol.SmartClientRequestProtocolOne(
830
 
                    medium_request)
 
631
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
831
632
                client_protocol.query_version()
832
633
                self._done_hello = True
833
 
            except errors.SmartProtocolError as e:
 
634
            except errors.SmartProtocolError, e:
834
635
                # Cache the error, just like we would cache a successful
835
636
                # result.
836
637
                self._protocol_version_error = e
868
669
        """
869
670
        medium_base = urlutils.join(self.base, '/')
870
671
        rel_url = urlutils.relative_url(medium_base, transport.base)
871
 
        return urlutils.unquote(rel_url)
 
672
        return urllib.unquote(rel_url)
872
673
 
873
674
 
874
675
class SmartClientStreamMedium(SmartClientMedium):
909
710
        """
910
711
        return SmartClientStreamMediumRequest(self)
911
712
 
912
 
    def reset(self):
913
 
        """We have been disconnected, reset current state.
914
 
 
915
 
        This resets things like _current_request and connected state.
916
 
        """
917
 
        self.disconnect()
918
 
        self._current_request = None
919
 
 
920
713
 
921
714
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
922
715
    """A client medium using simple pipes.
923
716
 
924
717
    This client does not manage the pipes: it assumes they will always be open.
 
718
 
 
719
    Note that if readable_pipe.read might raise IOError or OSError with errno
 
720
    of EINTR, it must be safe to retry the read.  Plain CPython fileobjects
 
721
    (such as used for sys.stdin) are safe.
925
722
    """
926
723
 
927
724
    def __init__(self, readable_pipe, writeable_pipe, base):
929
726
        self._readable_pipe = readable_pipe
930
727
        self._writeable_pipe = writeable_pipe
931
728
 
932
 
    def _accept_bytes(self, data):
 
729
    def _accept_bytes(self, bytes):
933
730
        """See SmartClientStreamMedium.accept_bytes."""
934
 
        try:
935
 
            self._writeable_pipe.write(data)
936
 
        except IOError as e:
937
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
938
 
                raise errors.ConnectionReset(
939
 
                    "Error trying to write to subprocess", e)
940
 
            raise
941
 
        self._report_activity(len(data), 'write')
 
731
        self._writeable_pipe.write(bytes)
 
732
        self._report_activity(len(bytes), 'write')
942
733
 
943
734
    def _flush(self):
944
735
        """See SmartClientStreamMedium._flush()."""
945
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
946
 
        #       However, testing shows that even when the child process is
947
 
        #       gone, this doesn't error.
948
736
        self._writeable_pipe.flush()
949
737
 
950
738
    def _read_bytes(self, count):
951
739
        """See SmartClientStreamMedium._read_bytes."""
952
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
953
 
        data = self._readable_pipe.read(bytes_to_read)
954
 
        self._report_activity(len(data), 'read')
955
 
        return data
956
 
 
957
 
 
958
 
class SSHParams(object):
959
 
    """A set of parameters for starting a remote bzr via SSH."""
 
740
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
 
741
        self._report_activity(len(bytes), 'read')
 
742
        return bytes
 
743
 
 
744
 
 
745
class SmartSSHClientMedium(SmartClientStreamMedium):
 
746
    """A client medium using SSH."""
960
747
 
961
748
    def __init__(self, host, port=None, username=None, password=None,
962
 
                 bzr_remote_path='bzr'):
963
 
        self.host = host
964
 
        self.port = port
965
 
        self.username = username
966
 
        self.password = password
967
 
        self.bzr_remote_path = bzr_remote_path
968
 
 
969
 
 
970
 
class SmartSSHClientMedium(SmartClientStreamMedium):
971
 
    """A client medium using SSH.
972
 
 
973
 
    It delegates IO to a SmartSimplePipesClientMedium or
974
 
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
975
 
    """
976
 
 
977
 
    def __init__(self, base, ssh_params, vendor=None):
 
749
            base=None, vendor=None, bzr_remote_path=None):
978
750
        """Creates a client that will connect on the first use.
979
751
 
980
 
        :param ssh_params: A SSHParams instance.
981
752
        :param vendor: An optional override for the ssh vendor to use. See
982
 
            breezy.transport.ssh for details on ssh vendors.
 
753
            bzrlib.transport.ssh for details on ssh vendors.
983
754
        """
984
 
        self._real_medium = None
985
 
        self._ssh_params = ssh_params
 
755
        self._connected = False
 
756
        self._host = host
 
757
        self._password = password
 
758
        self._port = port
 
759
        self._username = username
986
760
        # for the benefit of progress making a short description of this
987
761
        # transport
988
762
        self._scheme = 'bzr+ssh'
990
764
        # _DebugCounter so we have to store all the values used in our repr
991
765
        # method before calling the super init.
992
766
        SmartClientStreamMedium.__init__(self, base)
 
767
        self._read_from = None
 
768
        self._ssh_connection = None
993
769
        self._vendor = vendor
994
 
        self._ssh_connection = None
 
770
        self._write_to = None
 
771
        self._bzr_remote_path = bzr_remote_path
995
772
 
996
773
    def __repr__(self):
997
 
        if self._ssh_params.port is None:
 
774
        if self._port is None:
998
775
            maybe_port = ''
999
776
        else:
1000
 
            maybe_port = ':%s' % self._ssh_params.port
1001
 
        if self._ssh_params.username is None:
1002
 
            maybe_user = ''
1003
 
        else:
1004
 
            maybe_user = '%s@' % self._ssh_params.username
1005
 
        return "%s(%s://%s%s%s/)" % (
 
777
            maybe_port = ':%s' % self._port
 
778
        return "%s(%s://%s@%s%s/)" % (
1006
779
            self.__class__.__name__,
1007
780
            self._scheme,
1008
 
            maybe_user,
1009
 
            self._ssh_params.host,
 
781
            self._username,
 
782
            self._host,
1010
783
            maybe_port)
1011
784
 
1012
785
    def _accept_bytes(self, bytes):
1013
786
        """See SmartClientStreamMedium.accept_bytes."""
1014
787
        self._ensure_connection()
1015
 
        self._real_medium.accept_bytes(bytes)
 
788
        self._write_to.write(bytes)
 
789
        self._report_activity(len(bytes), 'write')
1016
790
 
1017
791
    def disconnect(self):
1018
792
        """See SmartClientMedium.disconnect()."""
1019
 
        if self._real_medium is not None:
1020
 
            self._real_medium.disconnect()
1021
 
            self._real_medium = None
1022
 
        if self._ssh_connection is not None:
1023
 
            self._ssh_connection.close()
1024
 
            self._ssh_connection = None
 
793
        if not self._connected:
 
794
            return
 
795
        self._read_from.close()
 
796
        self._write_to.close()
 
797
        self._ssh_connection.close()
 
798
        self._connected = False
1025
799
 
1026
800
    def _ensure_connection(self):
1027
801
        """Connect this medium if not already connected."""
1028
 
        if self._real_medium is not None:
 
802
        if self._connected:
1029
803
            return
1030
804
        if self._vendor is None:
1031
805
            vendor = ssh._get_ssh_vendor()
1032
806
        else:
1033
807
            vendor = self._vendor
1034
 
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1035
 
                                                  self._ssh_params.password, self._ssh_params.host,
1036
 
                                                  self._ssh_params.port,
1037
 
                                                  command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
1038
 
                                                           '--directory=/', '--allow-writes'])
1039
 
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1040
 
        if io_kind == 'socket':
1041
 
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1042
 
                self.base, io_object)
1043
 
        elif io_kind == 'pipes':
1044
 
            read_from, write_to = io_object
1045
 
            self._real_medium = SmartSimplePipesClientMedium(
1046
 
                read_from, write_to, self.base)
1047
 
        else:
1048
 
            raise AssertionError(
1049
 
                "Unexpected io_kind %r from %r"
1050
 
                % (io_kind, self._ssh_connection))
1051
 
        for hook in transport.Transport.hooks["post_connect"]:
1052
 
            hook(self)
 
808
        self._ssh_connection = vendor.connect_ssh(self._username,
 
809
                self._password, self._host, self._port,
 
810
                command=[self._bzr_remote_path, 'serve', '--inet',
 
811
                         '--directory=/', '--allow-writes'])
 
812
        self._read_from, self._write_to = \
 
813
            self._ssh_connection.get_filelike_channels()
 
814
        self._connected = True
1053
815
 
1054
816
    def _flush(self):
1055
817
        """See SmartClientStreamMedium._flush()."""
1056
 
        self._real_medium._flush()
 
818
        self._write_to.flush()
1057
819
 
1058
820
    def _read_bytes(self, count):
1059
821
        """See SmartClientStreamMedium.read_bytes."""
1060
 
        if self._real_medium is None:
 
822
        if not self._connected:
1061
823
            raise errors.MediumNotConnected(self)
1062
 
        return self._real_medium.read_bytes(count)
 
824
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
825
        bytes = self._read_from.read(bytes_to_read)
 
826
        self._report_activity(len(bytes), 'read')
 
827
        return bytes
1063
828
 
1064
829
 
1065
830
# Port 4155 is the default port for bzr://, registered with IANA.
1067
832
BZR_DEFAULT_PORT = 4155
1068
833
 
1069
834
 
1070
 
class SmartClientSocketMedium(SmartClientStreamMedium):
1071
 
    """A client medium using a socket.
1072
 
 
1073
 
    This class isn't usable directly.  Use one of its subclasses instead.
1074
 
    """
1075
 
 
1076
 
    def __init__(self, base):
 
835
class SmartTCPClientMedium(SmartClientStreamMedium):
 
836
    """A client medium using TCP."""
 
837
 
 
838
    def __init__(self, host, port, base):
 
839
        """Creates a client that will connect on the first use."""
1077
840
        SmartClientStreamMedium.__init__(self, base)
 
841
        self._connected = False
 
842
        self._host = host
 
843
        self._port = port
1078
844
        self._socket = None
1079
 
        self._connected = False
1080
845
 
1081
846
    def _accept_bytes(self, bytes):
1082
847
        """See SmartClientMedium.accept_bytes."""
1083
848
        self._ensure_connection()
1084
849
        osutils.send_all(self._socket, bytes, self._report_activity)
1085
850
 
1086
 
    def _ensure_connection(self):
1087
 
        """Connect this medium if not already connected."""
1088
 
        raise NotImplementedError(self._ensure_connection)
1089
 
 
1090
 
    def _flush(self):
1091
 
        """See SmartClientStreamMedium._flush().
1092
 
 
1093
 
        For sockets we do no flushing. For TCP sockets we may want to turn off
1094
 
        TCP_NODELAY and add a means to do a flush, but that can be done in the
1095
 
        future.
1096
 
        """
1097
 
 
1098
 
    def _read_bytes(self, count):
1099
 
        """See SmartClientMedium.read_bytes."""
1100
 
        if not self._connected:
1101
 
            raise errors.MediumNotConnected(self)
1102
 
        return osutils.read_bytes_from_socket(
1103
 
            self._socket, self._report_activity)
1104
 
 
1105
851
    def disconnect(self):
1106
852
        """See SmartClientMedium.disconnect()."""
1107
853
        if not self._connected:
1110
856
        self._socket = None
1111
857
        self._connected = False
1112
858
 
1113
 
 
1114
 
class SmartTCPClientMedium(SmartClientSocketMedium):
1115
 
    """A client medium that creates a TCP connection."""
1116
 
 
1117
 
    def __init__(self, host, port, base):
1118
 
        """Creates a client that will connect on the first use."""
1119
 
        SmartClientSocketMedium.__init__(self, base)
1120
 
        self._host = host
1121
 
        self._port = port
1122
 
 
1123
859
    def _ensure_connection(self):
1124
860
        """Connect this medium if not already connected."""
1125
861
        if self._connected:
1130
866
            port = int(self._port)
1131
867
        try:
1132
868
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1133
 
                                           socket.SOCK_STREAM, 0, 0)
1134
 
        except socket.gaierror as xxx_todo_changeme:
1135
 
            (err_num, err_msg) = xxx_todo_changeme.args
 
869
                socket.SOCK_STREAM, 0, 0)
 
870
        except socket.gaierror, (err_num, err_msg):
1136
871
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1137
 
                                         (self._host, port, err_msg))
 
872
                    (self._host, port, err_msg))
1138
873
        # Initialize err in case there are no addresses returned:
1139
 
        last_err = socket.error("no address found for %s" % self._host)
 
874
        err = socket.error("no address found for %s" % self._host)
1140
875
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1141
876
            try:
1142
877
                self._socket = socket.socket(family, socktype, proto)
1143
878
                self._socket.setsockopt(socket.IPPROTO_TCP,
1144
879
                                        socket.TCP_NODELAY, 1)
1145
880
                self._socket.connect(sockaddr)
1146
 
            except socket.error as err:
 
881
            except socket.error, err:
1147
882
                if self._socket is not None:
1148
883
                    self._socket.close()
1149
884
                self._socket = None
1150
 
                last_err = err
1151
885
                continue
1152
886
            break
1153
887
        if self._socket is None:
1154
888
            # socket errors either have a (string) or (errno, string) as their
1155
889
            # args.
1156
 
            if isinstance(last_err.args, str):
1157
 
                err_msg = last_err.args
 
890
            if type(err.args) is str:
 
891
                err_msg = err.args
1158
892
            else:
1159
 
                err_msg = last_err.args[1]
 
893
                err_msg = err.args[1]
1160
894
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1161
 
                                         (self._host, port, err_msg))
1162
 
        self._connected = True
1163
 
        for hook in transport.Transport.hooks["post_connect"]:
1164
 
            hook(self)
1165
 
 
1166
 
 
1167
 
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1168
 
    """A client medium for an already connected socket.
1169
 
 
1170
 
    Note that this class will assume it "owns" the socket, so it will close it
1171
 
    when its disconnect method is called.
1172
 
    """
1173
 
 
1174
 
    def __init__(self, base, sock):
1175
 
        SmartClientSocketMedium.__init__(self, base)
1176
 
        self._socket = sock
1177
 
        self._connected = True
1178
 
 
1179
 
    def _ensure_connection(self):
1180
 
        # Already connected, by definition!  So nothing to do.
1181
 
        pass
 
895
                    (self._host, port, err_msg))
 
896
        self._connected = True
 
897
 
 
898
    def _flush(self):
 
899
        """See SmartClientStreamMedium._flush().
 
900
 
 
901
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
 
902
        add a means to do a flush, but that can be done in the future.
 
903
        """
 
904
 
 
905
    def _read_bytes(self, count):
 
906
        """See SmartClientMedium.read_bytes."""
 
907
        if not self._connected:
 
908
            raise errors.MediumNotConnected(self)
 
909
        return osutils.read_bytes_from_socket(
 
910
            self._socket, self._report_activity)
1182
911
 
1183
912
 
1184
913
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1219
948
        This invokes self._medium._flush to ensure all bytes are transmitted.
1220
949
        """
1221
950
        self._medium._flush()
 
951
 
 
952