/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Vincent Ladeuil
  • Date: 2010-07-07 11:21:19 UTC
  • mto: (5193.7.1 unify-confs)
  • mto: This revision was merged to the branch mainline in revision 5349.
  • Revision ID: v.ladeuil+lp@free.fr-20100707112119-jwyh312df41w6l0o
Revert previous change as I can't reproduce the related problem anymore.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
breezy/transport/smart/__init__.py.
 
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
 
import errno
28
 
import io
29
27
import os
30
28
import sys
31
 
import time
32
 
 
33
 
try:
34
 
    import _thread
35
 
except ImportError:
36
 
    import thread as _thread
37
 
 
38
 
import breezy
39
 
from ...lazy_import import lazy_import
 
29
import urllib
 
30
 
 
31
import bzrlib
 
32
from bzrlib.lazy_import import lazy_import
40
33
lazy_import(globals(), """
41
 
import select
42
34
import socket
 
35
import thread
43
36
import weakref
44
37
 
45
 
from breezy import (
 
38
from bzrlib import (
46
39
    debug,
 
40
    errors,
 
41
    symbol_versioning,
47
42
    trace,
48
 
    transport,
49
43
    ui,
50
44
    urlutils,
51
45
    )
52
 
from breezy.i18n import gettext
53
 
from breezy.bzr.smart import client, protocol, request, signals, vfs
54
 
from breezy.transport import ssh
 
46
from bzrlib.smart import client, protocol, request, vfs
 
47
from bzrlib.transport import ssh
55
48
""")
56
 
from ... import (
57
 
    errors,
58
 
    osutils,
59
 
    )
 
49
from bzrlib import osutils
60
50
 
61
51
# Throughout this module buffer size parameters are either limited to be at
62
52
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
64
54
# from non-sockets as well.
65
55
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
66
56
 
67
 
 
68
 
class HpssVfsRequestNotAllowed(errors.BzrError):
69
 
 
70
 
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
71
 
            "%(method)s, %(arguments)s.")
72
 
 
73
 
    def __init__(self, method, arguments):
74
 
        self.method = method
75
 
        self.arguments = arguments
76
 
 
77
 
 
78
57
def _get_protocol_factory_for_bytes(bytes):
79
58
    """Determine the right protocol factory for 'bytes'.
80
59
 
116
95
    :returns: a tuple of two strs: (line, excess)
117
96
    """
118
97
    newline_pos = -1
119
 
    bytes = b''
 
98
    bytes = ''
120
99
    while newline_pos == -1:
121
100
        new_bytes = read_bytes_func(1)
122
101
        bytes += new_bytes
123
 
        if new_bytes == b'':
 
102
        if new_bytes == '':
124
103
            # Ran out of bytes before receiving a complete line.
125
 
            return bytes, b''
126
 
        newline_pos = bytes.find(b'\n')
127
 
    line = bytes[:newline_pos + 1]
128
 
    excess = bytes[newline_pos + 1:]
 
104
            return bytes, ''
 
105
        newline_pos = bytes.find('\n')
 
106
    line = bytes[:newline_pos+1]
 
107
    excess = bytes[newline_pos+1:]
129
108
    return line, excess
130
109
 
131
110
 
135
114
    def __init__(self):
136
115
        self._push_back_buffer = None
137
116
 
138
 
    def _push_back(self, data):
 
117
    def _push_back(self, bytes):
139
118
        """Return unused bytes to the medium, because they belong to the next
140
119
        request(s).
141
120
 
142
121
        This sets the _push_back_buffer to the given bytes.
143
122
        """
144
 
        if not isinstance(data, bytes):
145
 
            raise TypeError(data)
146
123
        if self._push_back_buffer is not None:
147
124
            raise AssertionError(
148
125
                "_push_back called when self._push_back_buffer is %r"
149
126
                % (self._push_back_buffer,))
150
 
        if data == b'':
 
127
        if bytes == '':
151
128
            return
152
 
        self._push_back_buffer = data
 
129
        self._push_back_buffer = bytes
153
130
 
154
131
    def _get_push_back_buffer(self):
155
 
        if self._push_back_buffer == b'':
 
132
        if self._push_back_buffer == '':
156
133
            raise AssertionError(
157
134
                '%s._push_back_buffer should never be the empty string, '
158
135
                'which can be confused with EOF' % (self,))
199
176
        ui.ui_factory.report_transport_activity(self, bytes, direction)
200
177
 
201
178
 
202
 
_bad_file_descriptor = (errno.EBADF,)
203
 
if sys.platform == 'win32':
204
 
    # Given on Windows if you pass a closed socket to select.select. Probably
205
 
    # also given if you pass a file handle to select.
206
 
    WSAENOTSOCK = 10038
207
 
    _bad_file_descriptor += (WSAENOTSOCK,)
208
 
 
209
 
 
210
179
class SmartServerStreamMedium(SmartMedium):
211
180
    """Handles smart commands coming over a stream.
212
181
 
225
194
        the stream.  See also the _push_back method.
226
195
    """
227
196
 
228
 
    _timer = time.time
229
 
 
230
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
197
    def __init__(self, backing_transport, root_client_path='/'):
231
198
        """Construct new server.
232
199
 
233
200
        :param backing_transport: Transport for the directory served.
236
203
        self.backing_transport = backing_transport
237
204
        self.root_client_path = root_client_path
238
205
        self.finished = False
239
 
        if timeout is None:
240
 
            raise AssertionError('You must supply a timeout.')
241
 
        self._client_timeout = timeout
242
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
243
206
        SmartMedium.__init__(self)
244
207
 
245
208
    def serve(self):
251
214
            while not self.finished:
252
215
                server_protocol = self._build_protocol()
253
216
                self._serve_one_request(server_protocol)
254
 
        except errors.ConnectionTimeout as e:
255
 
            trace.note('%s' % (e,))
256
 
            trace.log_exception_quietly()
257
 
            self._disconnect_client()
258
 
            # We reported it, no reason to make a big fuss.
259
 
            return
260
 
        except Exception as e:
 
217
        except Exception, e:
261
218
            stderr.write("%s terminating on exception %s\n" % (self, e))
262
219
            raise
263
 
        self._disconnect_client()
264
 
 
265
 
    def _stop_gracefully(self):
266
 
        """When we finish this message, stop looking for more."""
267
 
        trace.mutter('Stopping %s' % (self,))
268
 
        self.finished = True
269
 
 
270
 
    def _disconnect_client(self):
271
 
        """Close the current connection. We stopped due to a timeout/etc."""
272
 
        # The default implementation is a no-op, because that is all we used to
273
 
        # do when disconnecting from a client. I suppose we never had the
274
 
        # *server* initiate a disconnect, before
275
 
 
276
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
277
 
        """Wait for more bytes to be read, but timeout if none available.
278
 
 
279
 
        This allows us to detect idle connections, and stop trying to read from
280
 
        them, without setting the socket itself to non-blocking. This also
281
 
        allows us to specify when we watch for idle timeouts.
282
 
 
283
 
        :return: Did we timeout? (True if we timed out, False if there is data
284
 
            to be read)
285
 
        """
286
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
287
220
 
288
221
    def _build_protocol(self):
289
222
        """Identifies the version of the incoming request, and returns an
294
227
 
295
228
        :returns: a SmartServerRequestProtocol.
296
229
        """
297
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
298
 
        if self.finished:
299
 
            # We're stopping, so don't try to do any more work
300
 
            return None
301
230
        bytes = self._get_line()
302
231
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
303
232
        protocol = protocol_factory(
305
234
        protocol.accept_bytes(unused_bytes)
306
235
        return protocol
307
236
 
308
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
309
 
        """select() on a file descriptor, waiting for nonblocking read()
310
 
 
311
 
        This will raise a ConnectionTimeout exception if we do not get a
312
 
        readable handle before timeout_seconds.
313
 
        :return: None
314
 
        """
315
 
        t_end = self._timer() + timeout_seconds
316
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
317
 
        rs = xs = None
318
 
        while not rs and not xs and self._timer() < t_end:
319
 
            if self.finished:
320
 
                return
321
 
            try:
322
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
323
 
            except (select.error, socket.error) as e:
324
 
                err = getattr(e, 'errno', None)
325
 
                if err is None and getattr(e, 'args', None) is not None:
326
 
                    # select.error doesn't have 'errno', it just has args[0]
327
 
                    err = e.args[0]
328
 
                if err in _bad_file_descriptor:
329
 
                    return  # Not a socket indicates read() will fail
330
 
                elif err == errno.EINTR:
331
 
                    # Interrupted, keep looping.
332
 
                    continue
333
 
                raise
334
 
            except ValueError:
335
 
                return  # Socket may already be closed
336
 
        if rs or xs:
337
 
            return
338
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
339
 
                                       % (timeout_seconds,))
340
 
 
341
237
    def _serve_one_request(self, protocol):
342
238
        """Read one request from input, process, send back a response.
343
239
 
344
240
        :param protocol: a SmartServerRequestProtocol.
345
241
        """
346
 
        if protocol is None:
347
 
            return
348
242
        try:
349
243
            self._serve_one_request_unguarded(protocol)
350
244
        except KeyboardInterrupt:
351
245
            raise
352
 
        except Exception as e:
 
246
        except Exception, e:
353
247
            self.terminate_due_to_error()
354
248
 
355
249
    def terminate_due_to_error(self):
366
260
 
367
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
368
262
 
369
 
    def __init__(self, sock, backing_transport, root_client_path='/',
370
 
                 timeout=None):
 
263
    def __init__(self, sock, backing_transport, root_client_path='/'):
371
264
        """Constructor.
372
265
 
373
266
        :param sock: the socket the server will read from.  It will be put
374
267
            into blocking mode.
375
268
        """
376
269
        SmartServerStreamMedium.__init__(
377
 
            self, backing_transport, root_client_path=root_client_path,
378
 
            timeout=timeout)
 
270
            self, backing_transport, root_client_path=root_client_path)
379
271
        sock.setblocking(True)
380
272
        self.socket = sock
381
 
        # Get the getpeername now, as we might be closed later when we care.
382
 
        try:
383
 
            self._client_info = sock.getpeername()
384
 
        except socket.error:
385
 
            self._client_info = '<unknown>'
386
 
 
387
 
    def __str__(self):
388
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
389
 
 
390
 
    def __repr__(self):
391
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
392
 
                                     self._client_info)
393
273
 
394
274
    def _serve_one_request_unguarded(self, protocol):
395
275
        while protocol.next_read_size():
397
277
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
398
278
            # short read immediately rather than block.
399
279
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
400
 
            if bytes == b'':
 
280
            if bytes == '':
401
281
                self.finished = True
402
282
                return
403
283
            protocol.accept_bytes(bytes)
404
284
 
405
285
        self._push_back(protocol.unused_data)
406
286
 
407
 
    def _disconnect_client(self):
408
 
        """Close the current connection. We stopped due to a timeout/etc."""
409
 
        self.socket.close()
410
 
 
411
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
412
 
        """Wait for more bytes to be read, but timeout if none available.
413
 
 
414
 
        This allows us to detect idle connections, and stop trying to read from
415
 
        them, without setting the socket itself to non-blocking. This also
416
 
        allows us to specify when we watch for idle timeouts.
417
 
 
418
 
        :return: None, this will raise ConnectionTimeout if we time out before
419
 
            data is available.
420
 
        """
421
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
422
 
 
423
287
    def _read_bytes(self, desired_count):
424
288
        return osutils.read_bytes_from_socket(
425
289
            self.socket, self._report_activity)
431
295
        self.finished = True
432
296
 
433
297
    def _write_out(self, bytes):
434
 
        tstart = osutils.perf_counter()
 
298
        tstart = osutils.timer_func()
435
299
        osutils.send_all(self.socket, bytes, self._report_activity)
436
300
        if 'hpss' in debug.debug_flags:
437
 
            thread_id = _thread.get_ident()
 
301
            thread_id = thread.get_ident()
438
302
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
439
303
                         % ('wrote', thread_id, len(bytes),
440
 
                            osutils.perf_counter() - tstart))
 
304
                            osutils.timer_func() - tstart))
441
305
 
442
306
 
443
307
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
444
308
 
445
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
309
    def __init__(self, in_file, out_file, backing_transport):
446
310
        """Construct new server.
447
311
 
448
312
        :param in_file: Python file from which requests can be read.
449
313
        :param out_file: Python file to write responses.
450
314
        :param backing_transport: Transport for the directory served.
451
315
        """
452
 
        SmartServerStreamMedium.__init__(self, backing_transport,
453
 
                                         timeout=timeout)
 
316
        SmartServerStreamMedium.__init__(self, backing_transport)
454
317
        if sys.platform == 'win32':
455
318
            # force binary mode for files
456
319
            import msvcrt
461
324
        self._in = in_file
462
325
        self._out = out_file
463
326
 
464
 
    def serve(self):
465
 
        """See SmartServerStreamMedium.serve"""
466
 
        # This is the regular serve, except it adds signal trapping for soft
467
 
        # shutdown.
468
 
        stop_gracefully = self._stop_gracefully
469
 
        signals.register_on_hangup(id(self), stop_gracefully)
470
 
        try:
471
 
            return super(SmartServerPipeStreamMedium, self).serve()
472
 
        finally:
473
 
            signals.unregister_on_hangup(id(self))
474
 
 
475
327
    def _serve_one_request_unguarded(self, protocol):
476
328
        while True:
477
329
            # We need to be careful not to read past the end of the current
483
335
                self._out.flush()
484
336
                return
485
337
            bytes = self.read_bytes(bytes_to_read)
486
 
            if bytes == b'':
 
338
            if bytes == '':
487
339
                # Connection has been closed.
488
340
                self.finished = True
489
341
                self._out.flush()
490
342
                return
491
343
            protocol.accept_bytes(bytes)
492
344
 
493
 
    def _disconnect_client(self):
494
 
        self._in.close()
495
 
        self._out.flush()
496
 
        self._out.close()
497
 
 
498
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
499
 
        """Wait for more bytes to be read, but timeout if none available.
500
 
 
501
 
        This allows us to detect idle connections, and stop trying to read from
502
 
        them, without setting the socket itself to non-blocking. This also
503
 
        allows us to specify when we watch for idle timeouts.
504
 
 
505
 
        :return: None, this will raise ConnectionTimeout if we time out before
506
 
            data is available.
507
 
        """
508
 
        if (getattr(self._in, 'fileno', None) is None
509
 
                or sys.platform == 'win32'):
510
 
            # You can't select() file descriptors on Windows.
511
 
            return
512
 
        try:
513
 
            return self._wait_on_descriptor(self._in, timeout_seconds)
514
 
        except io.UnsupportedOperation:
515
 
            return
516
 
 
517
345
    def _read_bytes(self, desired_count):
518
346
        return self._in.read(desired_count)
519
347
 
647
475
 
648
476
    def read_line(self):
649
477
        line = self._read_line()
650
 
        if not line.endswith(b'\n'):
 
478
        if not line.endswith('\n'):
651
479
            # end of file encountered reading from server
652
480
            raise errors.ConnectionReset(
653
481
                "Unexpected end of message. Please check connectivity "
663
491
        return self._medium._get_line()
664
492
 
665
493
 
666
 
class _VfsRefuser(object):
667
 
    """An object that refuses all VFS requests.
668
 
 
669
 
    """
670
 
 
671
 
    def __init__(self):
672
 
        client._SmartClient.hooks.install_named_hook(
673
 
            'call', self.check_vfs, 'vfs refuser')
674
 
 
675
 
    def check_vfs(self, params):
676
 
        try:
677
 
            request_method = request.request_handlers.get(params.method)
678
 
        except KeyError:
679
 
            # A method we don't know about doesn't count as a VFS method.
680
 
            return
681
 
        if issubclass(request_method, vfs.VfsRequest):
682
 
            raise HpssVfsRequestNotAllowed(params.method, params.args)
683
 
 
684
 
 
685
494
class _DebugCounter(object):
686
495
    """An object that counts the HPSS calls made to each client medium.
687
496
 
688
497
    When a medium is garbage-collected, or failing that when
689
 
    breezy.global_state exits, the total number of calls made on that medium
 
498
    bzrlib.global_state exits, the total number of calls made on that medium
690
499
    are reported via trace.note.
691
500
    """
692
501
 
694
503
        self.counts = weakref.WeakKeyDictionary()
695
504
        client._SmartClient.hooks.install_named_hook(
696
505
            'call', self.increment_call_count, 'hpss call counter')
697
 
        breezy.get_global_state().exit_stack.callback(self.flush_all)
 
506
        bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
698
507
 
699
508
    def track(self, medium):
700
509
        """Start tracking calls made to a medium.
734
543
        value['count'] = 0
735
544
        value['vfs_count'] = 0
736
545
        if count != 0:
737
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
738
 
                       count, vfs_count, medium_repr))
 
546
            trace.note('HPSS calls: %d (%d vfs) %s',
 
547
                       count, vfs_count, medium_repr)
739
548
 
740
549
    def flush_all(self):
741
550
        for ref in list(self.counts.keys()):
742
551
            self.done(ref)
743
552
 
744
 
 
745
553
_debug_counter = None
746
 
_vfs_refuser = None
747
554
 
748
555
 
749
556
class SmartClientMedium(SmartMedium):
766
573
            if _debug_counter is None:
767
574
                _debug_counter = _DebugCounter()
768
575
            _debug_counter.track(self)
769
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
770
 
            global _vfs_refuser
771
 
            if _vfs_refuser is None:
772
 
                _vfs_refuser = _VfsRefuser()
773
576
 
774
577
    def _is_remote_before(self, version_tuple):
775
578
        """Is it possible the remote side supports RPCs for a given version?
799
602
        :seealso: _is_remote_before
800
603
        """
801
604
        if (self._remote_version_is_before is not None and
802
 
                version_tuple > self._remote_version_is_before):
 
605
            version_tuple > self._remote_version_is_before):
803
606
            # We have been told that the remote side is older than some version
804
607
            # which is newer than a previously supplied older-than version.
805
608
            # This indicates that some smart verb call is not guarded
806
609
            # appropriately (it should simply not have been tried).
807
610
            trace.mutter(
808
611
                "_remember_remote_is_before(%r) called, but "
809
 
                "_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before)
 
612
                "_remember_remote_is_before(%r) was called previously."
 
613
                , version_tuple, self._remote_version_is_before)
810
614
            if 'hpss' in debug.debug_flags:
811
615
                ui.ui_factory.show_warning(
812
616
                    "_remember_remote_is_before(%r) called, but "
824
628
                medium_request = self.get_request()
825
629
                # Send a 'hello' request in protocol version one, for maximum
826
630
                # backwards compatibility.
827
 
                client_protocol = protocol.SmartClientRequestProtocolOne(
828
 
                    medium_request)
 
631
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
829
632
                client_protocol.query_version()
830
633
                self._done_hello = True
831
 
            except errors.SmartProtocolError as e:
 
634
            except errors.SmartProtocolError, e:
832
635
                # Cache the error, just like we would cache a successful
833
636
                # result.
834
637
                self._protocol_version_error = e
866
669
        """
867
670
        medium_base = urlutils.join(self.base, '/')
868
671
        rel_url = urlutils.relative_url(medium_base, transport.base)
869
 
        return urlutils.unquote(rel_url)
 
672
        return urllib.unquote(rel_url)
870
673
 
871
674
 
872
675
class SmartClientStreamMedium(SmartClientMedium):
907
710
        """
908
711
        return SmartClientStreamMediumRequest(self)
909
712
 
910
 
    def reset(self):
911
 
        """We have been disconnected, reset current state.
912
 
 
913
 
        This resets things like _current_request and connected state.
914
 
        """
915
 
        self.disconnect()
916
 
        self._current_request = None
917
 
 
918
713
 
919
714
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
920
715
    """A client medium using simple pipes.
927
722
        self._readable_pipe = readable_pipe
928
723
        self._writeable_pipe = writeable_pipe
929
724
 
930
 
    def _accept_bytes(self, data):
 
725
    def _accept_bytes(self, bytes):
931
726
        """See SmartClientStreamMedium.accept_bytes."""
932
 
        try:
933
 
            self._writeable_pipe.write(data)
934
 
        except IOError as e:
935
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
936
 
                raise errors.ConnectionReset(
937
 
                    "Error trying to write to subprocess", e)
938
 
            raise
939
 
        self._report_activity(len(data), 'write')
 
727
        self._writeable_pipe.write(bytes)
 
728
        self._report_activity(len(bytes), 'write')
940
729
 
941
730
    def _flush(self):
942
731
        """See SmartClientStreamMedium._flush()."""
943
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
944
 
        #       However, testing shows that even when the child process is
945
 
        #       gone, this doesn't error.
946
732
        self._writeable_pipe.flush()
947
733
 
948
734
    def _read_bytes(self, count):
949
735
        """See SmartClientStreamMedium._read_bytes."""
950
736
        bytes_to_read = min(count, _MAX_READ_SIZE)
951
 
        data = self._readable_pipe.read(bytes_to_read)
952
 
        self._report_activity(len(data), 'read')
953
 
        return data
 
737
        bytes = self._readable_pipe.read(bytes_to_read)
 
738
        self._report_activity(len(bytes), 'read')
 
739
        return bytes
954
740
 
955
741
 
956
742
class SSHParams(object):
957
743
    """A set of parameters for starting a remote bzr via SSH."""
958
744
 
959
745
    def __init__(self, host, port=None, username=None, password=None,
960
 
                 bzr_remote_path='bzr'):
 
746
            bzr_remote_path='bzr'):
961
747
        self.host = host
962
748
        self.port = port
963
749
        self.username = username
967
753
 
968
754
class SmartSSHClientMedium(SmartClientStreamMedium):
969
755
    """A client medium using SSH.
970
 
 
971
 
    It delegates IO to a SmartSimplePipesClientMedium or
 
756
    
 
757
    It delegates IO to a SmartClientSocketMedium or
972
758
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
973
759
    """
974
760
 
977
763
 
978
764
        :param ssh_params: A SSHParams instance.
979
765
        :param vendor: An optional override for the ssh vendor to use. See
980
 
            breezy.transport.ssh for details on ssh vendors.
 
766
            bzrlib.transport.ssh for details on ssh vendors.
981
767
        """
982
768
        self._real_medium = None
983
769
        self._ssh_params = ssh_params
996
782
            maybe_port = ''
997
783
        else:
998
784
            maybe_port = ':%s' % self._ssh_params.port
999
 
        if self._ssh_params.username is None:
1000
 
            maybe_user = ''
1001
 
        else:
1002
 
            maybe_user = '%s@' % self._ssh_params.username
1003
 
        return "%s(%s://%s%s%s/)" % (
 
785
        return "%s(%s://%s@%s%s/)" % (
1004
786
            self.__class__.__name__,
1005
787
            self._scheme,
1006
 
            maybe_user,
 
788
            self._ssh_params.username,
1007
789
            self._ssh_params.host,
1008
790
            maybe_port)
1009
791
 
1030
812
        else:
1031
813
            vendor = self._vendor
1032
814
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1033
 
                                                  self._ssh_params.password, self._ssh_params.host,
1034
 
                                                  self._ssh_params.port,
1035
 
                                                  command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
1036
 
                                                           '--directory=/', '--allow-writes'])
 
815
                self._ssh_params.password, self._ssh_params.host,
 
816
                self._ssh_params.port,
 
817
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
 
818
                         '--directory=/', '--allow-writes'])
1037
819
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1038
820
        if io_kind == 'socket':
1039
821
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1046
828
            raise AssertionError(
1047
829
                "Unexpected io_kind %r from %r"
1048
830
                % (io_kind, self._ssh_connection))
1049
 
        for hook in transport.Transport.hooks["post_connect"]:
1050
 
            hook(self)
1051
831
 
1052
832
    def _flush(self):
1053
833
        """See SmartClientStreamMedium._flush()."""
1067
847
 
1068
848
class SmartClientSocketMedium(SmartClientStreamMedium):
1069
849
    """A client medium using a socket.
1070
 
 
 
850
    
1071
851
    This class isn't usable directly.  Use one of its subclasses instead.
1072
852
    """
1073
853
 
1128
908
            port = int(self._port)
1129
909
        try:
1130
910
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1131
 
                                           socket.SOCK_STREAM, 0, 0)
1132
 
        except socket.gaierror as xxx_todo_changeme:
1133
 
            (err_num, err_msg) = xxx_todo_changeme.args
 
911
                socket.SOCK_STREAM, 0, 0)
 
912
        except socket.gaierror, (err_num, err_msg):
1134
913
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1135
 
                                         (self._host, port, err_msg))
 
914
                    (self._host, port, err_msg))
1136
915
        # Initialize err in case there are no addresses returned:
1137
 
        last_err = socket.error("no address found for %s" % self._host)
 
916
        err = socket.error("no address found for %s" % self._host)
1138
917
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1139
918
            try:
1140
919
                self._socket = socket.socket(family, socktype, proto)
1141
920
                self._socket.setsockopt(socket.IPPROTO_TCP,
1142
921
                                        socket.TCP_NODELAY, 1)
1143
922
                self._socket.connect(sockaddr)
1144
 
            except socket.error as err:
 
923
            except socket.error, err:
1145
924
                if self._socket is not None:
1146
925
                    self._socket.close()
1147
926
                self._socket = None
1148
 
                last_err = err
1149
927
                continue
1150
928
            break
1151
929
        if self._socket is None:
1152
930
            # socket errors either have a (string) or (errno, string) as their
1153
931
            # args.
1154
 
            if isinstance(last_err.args, str):
1155
 
                err_msg = last_err.args
 
932
            if type(err.args) is str:
 
933
                err_msg = err.args
1156
934
            else:
1157
 
                err_msg = last_err.args[1]
 
935
                err_msg = err.args[1]
1158
936
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1159
 
                                         (self._host, port, err_msg))
 
937
                    (self._host, port, err_msg))
1160
938
        self._connected = True
1161
 
        for hook in transport.Transport.hooks["post_connect"]:
1162
 
            hook(self)
1163
939
 
1164
940
 
1165
941
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1166
942
    """A client medium for an already connected socket.
1167
 
 
 
943
    
1168
944
    Note that this class will assume it "owns" the socket, so it will close it
1169
945
    when its disconnect method is called.
1170
946
    """
1217
993
        This invokes self._medium._flush to ensure all bytes are transmitted.
1218
994
        """
1219
995
        self._medium._flush()
 
996
 
 
997