/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Marius Kruger
  • Date: 2010-07-10 21:28:56 UTC
  • mto: (5384.1.1 integration)
  • mto: This revision was merged to the branch mainline in revision 5385.
  • Revision ID: marius.kruger@enerweb.co.za-20100710212856-uq4ji3go0u5se7hx
* Update documentation
* add NEWS

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
breezy/transport/smart/__init__.py.
 
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
 
from __future__ import absolute_import
28
 
 
29
 
import errno
30
 
import io
31
27
import os
32
28
import sys
33
 
import time
34
 
 
35
 
try:
36
 
    import _thread
37
 
except ImportError:
38
 
    import thread as _thread
39
 
 
40
 
import breezy
41
 
from ...lazy_import import lazy_import
 
29
import urllib
 
30
 
 
31
import bzrlib
 
32
from bzrlib.lazy_import import lazy_import
42
33
lazy_import(globals(), """
43
 
import select
44
34
import socket
 
35
import thread
45
36
import weakref
46
37
 
47
 
from breezy import (
 
38
from bzrlib import (
48
39
    debug,
 
40
    errors,
 
41
    symbol_versioning,
49
42
    trace,
50
 
    transport,
51
43
    ui,
52
44
    urlutils,
53
45
    )
54
 
from breezy.i18n import gettext
55
 
from breezy.bzr.smart import client, protocol, request, signals, vfs
56
 
from breezy.transport import ssh
 
46
from bzrlib.smart import client, protocol, request, vfs
 
47
from bzrlib.transport import ssh
57
48
""")
58
 
from ... import (
59
 
    errors,
60
 
    osutils,
61
 
    )
 
49
from bzrlib import osutils
62
50
 
63
51
# Throughout this module buffer size parameters are either limited to be at
64
52
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
66
54
# from non-sockets as well.
67
55
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
68
56
 
69
 
 
70
 
class HpssVfsRequestNotAllowed(errors.BzrError):
71
 
 
72
 
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
73
 
            "%(method)s, %(arguments)s.")
74
 
 
75
 
    def __init__(self, method, arguments):
76
 
        self.method = method
77
 
        self.arguments = arguments
78
 
 
79
 
 
80
57
def _get_protocol_factory_for_bytes(bytes):
81
58
    """Determine the right protocol factory for 'bytes'.
82
59
 
118
95
    :returns: a tuple of two strs: (line, excess)
119
96
    """
120
97
    newline_pos = -1
121
 
    bytes = b''
 
98
    bytes = ''
122
99
    while newline_pos == -1:
123
100
        new_bytes = read_bytes_func(1)
124
101
        bytes += new_bytes
125
 
        if new_bytes == b'':
 
102
        if new_bytes == '':
126
103
            # Ran out of bytes before receiving a complete line.
127
 
            return bytes, b''
128
 
        newline_pos = bytes.find(b'\n')
129
 
    line = bytes[:newline_pos + 1]
130
 
    excess = bytes[newline_pos + 1:]
 
104
            return bytes, ''
 
105
        newline_pos = bytes.find('\n')
 
106
    line = bytes[:newline_pos+1]
 
107
    excess = bytes[newline_pos+1:]
131
108
    return line, excess
132
109
 
133
110
 
137
114
    def __init__(self):
138
115
        self._push_back_buffer = None
139
116
 
140
 
    def _push_back(self, data):
 
117
    def _push_back(self, bytes):
141
118
        """Return unused bytes to the medium, because they belong to the next
142
119
        request(s).
143
120
 
144
121
        This sets the _push_back_buffer to the given bytes.
145
122
        """
146
 
        if not isinstance(data, bytes):
147
 
            raise TypeError(data)
148
123
        if self._push_back_buffer is not None:
149
124
            raise AssertionError(
150
125
                "_push_back called when self._push_back_buffer is %r"
151
126
                % (self._push_back_buffer,))
152
 
        if data == b'':
 
127
        if bytes == '':
153
128
            return
154
 
        self._push_back_buffer = data
 
129
        self._push_back_buffer = bytes
155
130
 
156
131
    def _get_push_back_buffer(self):
157
 
        if self._push_back_buffer == b'':
 
132
        if self._push_back_buffer == '':
158
133
            raise AssertionError(
159
134
                '%s._push_back_buffer should never be the empty string, '
160
135
                'which can be confused with EOF' % (self,))
201
176
        ui.ui_factory.report_transport_activity(self, bytes, direction)
202
177
 
203
178
 
204
 
_bad_file_descriptor = (errno.EBADF,)
205
 
if sys.platform == 'win32':
206
 
    # Given on Windows if you pass a closed socket to select.select. Probably
207
 
    # also given if you pass a file handle to select.
208
 
    WSAENOTSOCK = 10038
209
 
    _bad_file_descriptor += (WSAENOTSOCK,)
210
 
 
211
 
 
212
179
class SmartServerStreamMedium(SmartMedium):
213
180
    """Handles smart commands coming over a stream.
214
181
 
227
194
        the stream.  See also the _push_back method.
228
195
    """
229
196
 
230
 
    _timer = time.time
231
 
 
232
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
197
    def __init__(self, backing_transport, root_client_path='/'):
233
198
        """Construct new server.
234
199
 
235
200
        :param backing_transport: Transport for the directory served.
238
203
        self.backing_transport = backing_transport
239
204
        self.root_client_path = root_client_path
240
205
        self.finished = False
241
 
        if timeout is None:
242
 
            raise AssertionError('You must supply a timeout.')
243
 
        self._client_timeout = timeout
244
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
245
206
        SmartMedium.__init__(self)
246
207
 
247
208
    def serve(self):
253
214
            while not self.finished:
254
215
                server_protocol = self._build_protocol()
255
216
                self._serve_one_request(server_protocol)
256
 
        except errors.ConnectionTimeout as e:
257
 
            trace.note('%s' % (e,))
258
 
            trace.log_exception_quietly()
259
 
            self._disconnect_client()
260
 
            # We reported it, no reason to make a big fuss.
261
 
            return
262
 
        except Exception as e:
 
217
        except Exception, e:
263
218
            stderr.write("%s terminating on exception %s\n" % (self, e))
264
219
            raise
265
 
        self._disconnect_client()
266
 
 
267
 
    def _stop_gracefully(self):
268
 
        """When we finish this message, stop looking for more."""
269
 
        trace.mutter('Stopping %s' % (self,))
270
 
        self.finished = True
271
 
 
272
 
    def _disconnect_client(self):
273
 
        """Close the current connection. We stopped due to a timeout/etc."""
274
 
        # The default implementation is a no-op, because that is all we used to
275
 
        # do when disconnecting from a client. I suppose we never had the
276
 
        # *server* initiate a disconnect, before
277
 
 
278
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
279
 
        """Wait for more bytes to be read, but timeout if none available.
280
 
 
281
 
        This allows us to detect idle connections, and stop trying to read from
282
 
        them, without setting the socket itself to non-blocking. This also
283
 
        allows us to specify when we watch for idle timeouts.
284
 
 
285
 
        :return: Did we timeout? (True if we timed out, False if there is data
286
 
            to be read)
287
 
        """
288
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
289
220
 
290
221
    def _build_protocol(self):
291
222
        """Identifies the version of the incoming request, and returns an
296
227
 
297
228
        :returns: a SmartServerRequestProtocol.
298
229
        """
299
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
300
 
        if self.finished:
301
 
            # We're stopping, so don't try to do any more work
302
 
            return None
303
230
        bytes = self._get_line()
304
231
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
305
232
        protocol = protocol_factory(
307
234
        protocol.accept_bytes(unused_bytes)
308
235
        return protocol
309
236
 
310
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
311
 
        """select() on a file descriptor, waiting for nonblocking read()
312
 
 
313
 
        This will raise a ConnectionTimeout exception if we do not get a
314
 
        readable handle before timeout_seconds.
315
 
        :return: None
316
 
        """
317
 
        t_end = self._timer() + timeout_seconds
318
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
319
 
        rs = xs = None
320
 
        while not rs and not xs and self._timer() < t_end:
321
 
            if self.finished:
322
 
                return
323
 
            try:
324
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
325
 
            except (select.error, socket.error) as e:
326
 
                err = getattr(e, 'errno', None)
327
 
                if err is None and getattr(e, 'args', None) is not None:
328
 
                    # select.error doesn't have 'errno', it just has args[0]
329
 
                    err = e.args[0]
330
 
                if err in _bad_file_descriptor:
331
 
                    return  # Not a socket indicates read() will fail
332
 
                elif err == errno.EINTR:
333
 
                    # Interrupted, keep looping.
334
 
                    continue
335
 
                raise
336
 
            except ValueError:
337
 
                return  # Socket may already be closed
338
 
        if rs or xs:
339
 
            return
340
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
341
 
                                       % (timeout_seconds,))
342
 
 
343
237
    def _serve_one_request(self, protocol):
344
238
        """Read one request from input, process, send back a response.
345
239
 
346
240
        :param protocol: a SmartServerRequestProtocol.
347
241
        """
348
 
        if protocol is None:
349
 
            return
350
242
        try:
351
243
            self._serve_one_request_unguarded(protocol)
352
244
        except KeyboardInterrupt:
353
245
            raise
354
 
        except Exception as e:
 
246
        except Exception, e:
355
247
            self.terminate_due_to_error()
356
248
 
357
249
    def terminate_due_to_error(self):
368
260
 
369
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
370
262
 
371
 
    def __init__(self, sock, backing_transport, root_client_path='/',
372
 
                 timeout=None):
 
263
    def __init__(self, sock, backing_transport, root_client_path='/'):
373
264
        """Constructor.
374
265
 
375
266
        :param sock: the socket the server will read from.  It will be put
376
267
            into blocking mode.
377
268
        """
378
269
        SmartServerStreamMedium.__init__(
379
 
            self, backing_transport, root_client_path=root_client_path,
380
 
            timeout=timeout)
 
270
            self, backing_transport, root_client_path=root_client_path)
381
271
        sock.setblocking(True)
382
272
        self.socket = sock
383
 
        # Get the getpeername now, as we might be closed later when we care.
384
 
        try:
385
 
            self._client_info = sock.getpeername()
386
 
        except socket.error:
387
 
            self._client_info = '<unknown>'
388
 
 
389
 
    def __str__(self):
390
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
391
 
 
392
 
    def __repr__(self):
393
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
394
 
                                     self._client_info)
395
273
 
396
274
    def _serve_one_request_unguarded(self, protocol):
397
275
        while protocol.next_read_size():
399
277
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
400
278
            # short read immediately rather than block.
401
279
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
402
 
            if bytes == b'':
 
280
            if bytes == '':
403
281
                self.finished = True
404
282
                return
405
283
            protocol.accept_bytes(bytes)
406
284
 
407
285
        self._push_back(protocol.unused_data)
408
286
 
409
 
    def _disconnect_client(self):
410
 
        """Close the current connection. We stopped due to a timeout/etc."""
411
 
        self.socket.close()
412
 
 
413
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
414
 
        """Wait for more bytes to be read, but timeout if none available.
415
 
 
416
 
        This allows us to detect idle connections, and stop trying to read from
417
 
        them, without setting the socket itself to non-blocking. This also
418
 
        allows us to specify when we watch for idle timeouts.
419
 
 
420
 
        :return: None, this will raise ConnectionTimeout if we time out before
421
 
            data is available.
422
 
        """
423
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
424
 
 
425
287
    def _read_bytes(self, desired_count):
426
288
        return osutils.read_bytes_from_socket(
427
289
            self.socket, self._report_activity)
433
295
        self.finished = True
434
296
 
435
297
    def _write_out(self, bytes):
436
 
        tstart = osutils.perf_counter()
 
298
        tstart = osutils.timer_func()
437
299
        osutils.send_all(self.socket, bytes, self._report_activity)
438
300
        if 'hpss' in debug.debug_flags:
439
 
            thread_id = _thread.get_ident()
 
301
            thread_id = thread.get_ident()
440
302
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
441
303
                         % ('wrote', thread_id, len(bytes),
442
 
                            osutils.perf_counter() - tstart))
 
304
                            osutils.timer_func() - tstart))
443
305
 
444
306
 
445
307
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
446
308
 
447
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
309
    def __init__(self, in_file, out_file, backing_transport):
448
310
        """Construct new server.
449
311
 
450
312
        :param in_file: Python file from which requests can be read.
451
313
        :param out_file: Python file to write responses.
452
314
        :param backing_transport: Transport for the directory served.
453
315
        """
454
 
        SmartServerStreamMedium.__init__(self, backing_transport,
455
 
                                         timeout=timeout)
 
316
        SmartServerStreamMedium.__init__(self, backing_transport)
456
317
        if sys.platform == 'win32':
457
318
            # force binary mode for files
458
319
            import msvcrt
463
324
        self._in = in_file
464
325
        self._out = out_file
465
326
 
466
 
    def serve(self):
467
 
        """See SmartServerStreamMedium.serve"""
468
 
        # This is the regular serve, except it adds signal trapping for soft
469
 
        # shutdown.
470
 
        stop_gracefully = self._stop_gracefully
471
 
        signals.register_on_hangup(id(self), stop_gracefully)
472
 
        try:
473
 
            return super(SmartServerPipeStreamMedium, self).serve()
474
 
        finally:
475
 
            signals.unregister_on_hangup(id(self))
476
 
 
477
327
    def _serve_one_request_unguarded(self, protocol):
478
328
        while True:
479
329
            # We need to be careful not to read past the end of the current
485
335
                self._out.flush()
486
336
                return
487
337
            bytes = self.read_bytes(bytes_to_read)
488
 
            if bytes == b'':
 
338
            if bytes == '':
489
339
                # Connection has been closed.
490
340
                self.finished = True
491
341
                self._out.flush()
492
342
                return
493
343
            protocol.accept_bytes(bytes)
494
344
 
495
 
    def _disconnect_client(self):
496
 
        self._in.close()
497
 
        self._out.flush()
498
 
        self._out.close()
499
 
 
500
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
501
 
        """Wait for more bytes to be read, but timeout if none available.
502
 
 
503
 
        This allows us to detect idle connections, and stop trying to read from
504
 
        them, without setting the socket itself to non-blocking. This also
505
 
        allows us to specify when we watch for idle timeouts.
506
 
 
507
 
        :return: None, this will raise ConnectionTimeout if we time out before
508
 
            data is available.
509
 
        """
510
 
        if (getattr(self._in, 'fileno', None) is None
511
 
                or sys.platform == 'win32'):
512
 
            # You can't select() file descriptors on Windows.
513
 
            return
514
 
        try:
515
 
            return self._wait_on_descriptor(self._in, timeout_seconds)
516
 
        except io.UnsupportedOperation:
517
 
            return
518
 
 
519
345
    def _read_bytes(self, desired_count):
520
346
        return self._in.read(desired_count)
521
347
 
649
475
 
650
476
    def read_line(self):
651
477
        line = self._read_line()
652
 
        if not line.endswith(b'\n'):
 
478
        if not line.endswith('\n'):
653
479
            # end of file encountered reading from server
654
480
            raise errors.ConnectionReset(
655
481
                "Unexpected end of message. Please check connectivity "
665
491
        return self._medium._get_line()
666
492
 
667
493
 
668
 
class _VfsRefuser(object):
669
 
    """An object that refuses all VFS requests.
670
 
 
671
 
    """
672
 
 
673
 
    def __init__(self):
674
 
        client._SmartClient.hooks.install_named_hook(
675
 
            'call', self.check_vfs, 'vfs refuser')
676
 
 
677
 
    def check_vfs(self, params):
678
 
        try:
679
 
            request_method = request.request_handlers.get(params.method)
680
 
        except KeyError:
681
 
            # A method we don't know about doesn't count as a VFS method.
682
 
            return
683
 
        if issubclass(request_method, vfs.VfsRequest):
684
 
            raise HpssVfsRequestNotAllowed(params.method, params.args)
685
 
 
686
 
 
687
494
class _DebugCounter(object):
688
495
    """An object that counts the HPSS calls made to each client medium.
689
496
 
690
497
    When a medium is garbage-collected, or failing that when
691
 
    breezy.global_state exits, the total number of calls made on that medium
 
498
    bzrlib.global_state exits, the total number of calls made on that medium
692
499
    are reported via trace.note.
693
500
    """
694
501
 
696
503
        self.counts = weakref.WeakKeyDictionary()
697
504
        client._SmartClient.hooks.install_named_hook(
698
505
            'call', self.increment_call_count, 'hpss call counter')
699
 
        breezy.get_global_state().exit_stack.callback(self.flush_all)
 
506
        bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
700
507
 
701
508
    def track(self, medium):
702
509
        """Start tracking calls made to a medium.
736
543
        value['count'] = 0
737
544
        value['vfs_count'] = 0
738
545
        if count != 0:
739
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
740
 
                       count, vfs_count, medium_repr))
 
546
            trace.note('HPSS calls: %d (%d vfs) %s',
 
547
                       count, vfs_count, medium_repr)
741
548
 
742
549
    def flush_all(self):
743
550
        for ref in list(self.counts.keys()):
744
551
            self.done(ref)
745
552
 
746
 
 
747
553
_debug_counter = None
748
 
_vfs_refuser = None
749
554
 
750
555
 
751
556
class SmartClientMedium(SmartMedium):
768
573
            if _debug_counter is None:
769
574
                _debug_counter = _DebugCounter()
770
575
            _debug_counter.track(self)
771
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
772
 
            global _vfs_refuser
773
 
            if _vfs_refuser is None:
774
 
                _vfs_refuser = _VfsRefuser()
775
576
 
776
577
    def _is_remote_before(self, version_tuple):
777
578
        """Is it possible the remote side supports RPCs for a given version?
801
602
        :seealso: _is_remote_before
802
603
        """
803
604
        if (self._remote_version_is_before is not None and
804
 
                version_tuple > self._remote_version_is_before):
 
605
            version_tuple > self._remote_version_is_before):
805
606
            # We have been told that the remote side is older than some version
806
607
            # which is newer than a previously supplied older-than version.
807
608
            # This indicates that some smart verb call is not guarded
808
609
            # appropriately (it should simply not have been tried).
809
610
            trace.mutter(
810
611
                "_remember_remote_is_before(%r) called, but "
811
 
                "_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before)
 
612
                "_remember_remote_is_before(%r) was called previously."
 
613
                , version_tuple, self._remote_version_is_before)
812
614
            if 'hpss' in debug.debug_flags:
813
615
                ui.ui_factory.show_warning(
814
616
                    "_remember_remote_is_before(%r) called, but "
826
628
                medium_request = self.get_request()
827
629
                # Send a 'hello' request in protocol version one, for maximum
828
630
                # backwards compatibility.
829
 
                client_protocol = protocol.SmartClientRequestProtocolOne(
830
 
                    medium_request)
 
631
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
831
632
                client_protocol.query_version()
832
633
                self._done_hello = True
833
 
            except errors.SmartProtocolError as e:
 
634
            except errors.SmartProtocolError, e:
834
635
                # Cache the error, just like we would cache a successful
835
636
                # result.
836
637
                self._protocol_version_error = e
868
669
        """
869
670
        medium_base = urlutils.join(self.base, '/')
870
671
        rel_url = urlutils.relative_url(medium_base, transport.base)
871
 
        return urlutils.unquote(rel_url)
 
672
        return urllib.unquote(rel_url)
872
673
 
873
674
 
874
675
class SmartClientStreamMedium(SmartClientMedium):
909
710
        """
910
711
        return SmartClientStreamMediumRequest(self)
911
712
 
912
 
    def reset(self):
913
 
        """We have been disconnected, reset current state.
914
 
 
915
 
        This resets things like _current_request and connected state.
916
 
        """
917
 
        self.disconnect()
918
 
        self._current_request = None
919
 
 
920
713
 
921
714
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
922
715
    """A client medium using simple pipes.
929
722
        self._readable_pipe = readable_pipe
930
723
        self._writeable_pipe = writeable_pipe
931
724
 
932
 
    def _accept_bytes(self, data):
 
725
    def _accept_bytes(self, bytes):
933
726
        """See SmartClientStreamMedium.accept_bytes."""
934
 
        try:
935
 
            self._writeable_pipe.write(data)
936
 
        except IOError as e:
937
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
938
 
                raise errors.ConnectionReset(
939
 
                    "Error trying to write to subprocess", e)
940
 
            raise
941
 
        self._report_activity(len(data), 'write')
 
727
        self._writeable_pipe.write(bytes)
 
728
        self._report_activity(len(bytes), 'write')
942
729
 
943
730
    def _flush(self):
944
731
        """See SmartClientStreamMedium._flush()."""
945
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
946
 
        #       However, testing shows that even when the child process is
947
 
        #       gone, this doesn't error.
948
732
        self._writeable_pipe.flush()
949
733
 
950
734
    def _read_bytes(self, count):
951
735
        """See SmartClientStreamMedium._read_bytes."""
952
736
        bytes_to_read = min(count, _MAX_READ_SIZE)
953
 
        data = self._readable_pipe.read(bytes_to_read)
954
 
        self._report_activity(len(data), 'read')
955
 
        return data
 
737
        bytes = self._readable_pipe.read(bytes_to_read)
 
738
        self._report_activity(len(bytes), 'read')
 
739
        return bytes
956
740
 
957
741
 
958
742
class SSHParams(object):
959
743
    """A set of parameters for starting a remote bzr via SSH."""
960
744
 
961
745
    def __init__(self, host, port=None, username=None, password=None,
962
 
                 bzr_remote_path='bzr'):
 
746
            bzr_remote_path='bzr'):
963
747
        self.host = host
964
748
        self.port = port
965
749
        self.username = username
969
753
 
970
754
class SmartSSHClientMedium(SmartClientStreamMedium):
971
755
    """A client medium using SSH.
972
 
 
973
 
    It delegates IO to a SmartSimplePipesClientMedium or
 
756
    
 
757
    It delegates IO to a SmartClientSocketMedium or
974
758
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
975
759
    """
976
760
 
979
763
 
980
764
        :param ssh_params: A SSHParams instance.
981
765
        :param vendor: An optional override for the ssh vendor to use. See
982
 
            breezy.transport.ssh for details on ssh vendors.
 
766
            bzrlib.transport.ssh for details on ssh vendors.
983
767
        """
984
768
        self._real_medium = None
985
769
        self._ssh_params = ssh_params
998
782
            maybe_port = ''
999
783
        else:
1000
784
            maybe_port = ':%s' % self._ssh_params.port
1001
 
        if self._ssh_params.username is None:
1002
 
            maybe_user = ''
1003
 
        else:
1004
 
            maybe_user = '%s@' % self._ssh_params.username
1005
 
        return "%s(%s://%s%s%s/)" % (
 
785
        return "%s(%s://%s@%s%s/)" % (
1006
786
            self.__class__.__name__,
1007
787
            self._scheme,
1008
 
            maybe_user,
 
788
            self._ssh_params.username,
1009
789
            self._ssh_params.host,
1010
790
            maybe_port)
1011
791
 
1032
812
        else:
1033
813
            vendor = self._vendor
1034
814
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1035
 
                                                  self._ssh_params.password, self._ssh_params.host,
1036
 
                                                  self._ssh_params.port,
1037
 
                                                  command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
1038
 
                                                           '--directory=/', '--allow-writes'])
 
815
                self._ssh_params.password, self._ssh_params.host,
 
816
                self._ssh_params.port,
 
817
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
 
818
                         '--directory=/', '--allow-writes'])
1039
819
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1040
820
        if io_kind == 'socket':
1041
821
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1048
828
            raise AssertionError(
1049
829
                "Unexpected io_kind %r from %r"
1050
830
                % (io_kind, self._ssh_connection))
1051
 
        for hook in transport.Transport.hooks["post_connect"]:
1052
 
            hook(self)
1053
831
 
1054
832
    def _flush(self):
1055
833
        """See SmartClientStreamMedium._flush()."""
1069
847
 
1070
848
class SmartClientSocketMedium(SmartClientStreamMedium):
1071
849
    """A client medium using a socket.
1072
 
 
 
850
    
1073
851
    This class isn't usable directly.  Use one of its subclasses instead.
1074
852
    """
1075
853
 
1130
908
            port = int(self._port)
1131
909
        try:
1132
910
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1133
 
                                           socket.SOCK_STREAM, 0, 0)
1134
 
        except socket.gaierror as xxx_todo_changeme:
1135
 
            (err_num, err_msg) = xxx_todo_changeme.args
 
911
                socket.SOCK_STREAM, 0, 0)
 
912
        except socket.gaierror, (err_num, err_msg):
1136
913
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1137
 
                                         (self._host, port, err_msg))
 
914
                    (self._host, port, err_msg))
1138
915
        # Initialize err in case there are no addresses returned:
1139
 
        last_err = socket.error("no address found for %s" % self._host)
 
916
        err = socket.error("no address found for %s" % self._host)
1140
917
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1141
918
            try:
1142
919
                self._socket = socket.socket(family, socktype, proto)
1143
920
                self._socket.setsockopt(socket.IPPROTO_TCP,
1144
921
                                        socket.TCP_NODELAY, 1)
1145
922
                self._socket.connect(sockaddr)
1146
 
            except socket.error as err:
 
923
            except socket.error, err:
1147
924
                if self._socket is not None:
1148
925
                    self._socket.close()
1149
926
                self._socket = None
1150
 
                last_err = err
1151
927
                continue
1152
928
            break
1153
929
        if self._socket is None:
1154
930
            # socket errors either have a (string) or (errno, string) as their
1155
931
            # args.
1156
 
            if isinstance(last_err.args, str):
1157
 
                err_msg = last_err.args
 
932
            if type(err.args) is str:
 
933
                err_msg = err.args
1158
934
            else:
1159
 
                err_msg = last_err.args[1]
 
935
                err_msg = err.args[1]
1160
936
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1161
 
                                         (self._host, port, err_msg))
 
937
                    (self._host, port, err_msg))
1162
938
        self._connected = True
1163
 
        for hook in transport.Transport.hooks["post_connect"]:
1164
 
            hook(self)
1165
939
 
1166
940
 
1167
941
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1168
942
    """A client medium for an already connected socket.
1169
 
 
 
943
    
1170
944
    Note that this class will assume it "owns" the socket, so it will close it
1171
945
    when its disconnect method is called.
1172
946
    """
1219
993
        This invokes self._medium._flush to ensure all bytes are transmitted.
1220
994
        """
1221
995
        self._medium._flush()
 
996
 
 
997