/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/bzr/smart/medium.py

  • Committer: Jelmer Vernooij
  • Date: 2018-02-27 14:36:14 UTC
  • mto: This revision was merged to the branch mainline in revision 6866.
  • Revision ID: jelmer@jelmer.uk-20180227143614-2cuc06ngefq77gke
Move to errors.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006-2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
bzrlib/transport/smart/__init__.py.
 
24
breezy/transport/smart/__init__.py.
25
25
"""
26
26
 
 
27
from __future__ import absolute_import
 
28
 
 
29
import errno
27
30
import os
28
31
import sys
29
 
import urllib
 
32
import time
30
33
 
31
 
from bzrlib.lazy_import import lazy_import
 
34
import breezy
 
35
from ...lazy_import import lazy_import
32
36
lazy_import(globals(), """
33
 
import atexit
 
37
import select
34
38
import socket
35
39
import thread
36
40
import weakref
37
41
 
38
 
from bzrlib import (
 
42
from breezy import (
39
43
    debug,
40
 
    errors,
41
 
    symbol_versioning,
42
44
    trace,
 
45
    transport,
43
46
    ui,
44
47
    urlutils,
45
48
    )
46
 
from bzrlib.smart import client, protocol, request, vfs
47
 
from bzrlib.transport import ssh
 
49
from breezy.i18n import gettext
 
50
from breezy.bzr.smart import client, protocol, request, signals, vfs
 
51
from breezy.transport import ssh
48
52
""")
49
 
from bzrlib import osutils
 
53
from ... import (
 
54
    errors,
 
55
    osutils,
 
56
    )
50
57
 
51
58
# Throughout this module buffer size parameters are either limited to be at
52
59
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
54
61
# from non-sockets as well.
55
62
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
56
63
 
 
64
 
 
65
class HpssVfsRequestNotAllowed(errors.BzrError):
 
66
 
 
67
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
 
68
            "%(method)s, %(arguments)s.")
 
69
 
 
70
    def __init__(self, method, arguments):
 
71
        self.method = method
 
72
        self.arguments = arguments
 
73
 
 
74
 
57
75
def _get_protocol_factory_for_bytes(bytes):
58
76
    """Determine the right protocol factory for 'bytes'.
59
77
 
176
194
        ui.ui_factory.report_transport_activity(self, bytes, direction)
177
195
 
178
196
 
 
197
_bad_file_descriptor = (errno.EBADF,)
 
198
if sys.platform == 'win32':
 
199
    # Given on Windows if you pass a closed socket to select.select. Probably
 
200
    # also given if you pass a file handle to select.
 
201
    WSAENOTSOCK = 10038
 
202
    _bad_file_descriptor += (WSAENOTSOCK,)
 
203
 
 
204
 
179
205
class SmartServerStreamMedium(SmartMedium):
180
206
    """Handles smart commands coming over a stream.
181
207
 
194
220
        the stream.  See also the _push_back method.
195
221
    """
196
222
 
197
 
    def __init__(self, backing_transport, root_client_path='/'):
 
223
    _timer = time.time
 
224
 
 
225
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
198
226
        """Construct new server.
199
227
 
200
228
        :param backing_transport: Transport for the directory served.
203
231
        self.backing_transport = backing_transport
204
232
        self.root_client_path = root_client_path
205
233
        self.finished = False
 
234
        if timeout is None:
 
235
            raise AssertionError('You must supply a timeout.')
 
236
        self._client_timeout = timeout
 
237
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
206
238
        SmartMedium.__init__(self)
207
239
 
208
240
    def serve(self):
214
246
            while not self.finished:
215
247
                server_protocol = self._build_protocol()
216
248
                self._serve_one_request(server_protocol)
217
 
        except Exception, e:
 
249
        except errors.ConnectionTimeout as e:
 
250
            trace.note('%s' % (e,))
 
251
            trace.log_exception_quietly()
 
252
            self._disconnect_client()
 
253
            # We reported it, no reason to make a big fuss.
 
254
            return
 
255
        except Exception as e:
218
256
            stderr.write("%s terminating on exception %s\n" % (self, e))
219
257
            raise
 
258
        self._disconnect_client()
 
259
 
 
260
    def _stop_gracefully(self):
 
261
        """When we finish this message, stop looking for more."""
 
262
        trace.mutter('Stopping %s' % (self,))
 
263
        self.finished = True
 
264
 
 
265
    def _disconnect_client(self):
 
266
        """Close the current connection. We stopped due to a timeout/etc."""
 
267
        # The default implementation is a no-op, because that is all we used to
 
268
        # do when disconnecting from a client. I suppose we never had the
 
269
        # *server* initiate a disconnect, before
 
270
 
 
271
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
272
        """Wait for more bytes to be read, but timeout if none available.
 
273
 
 
274
        This allows us to detect idle connections, and stop trying to read from
 
275
        them, without setting the socket itself to non-blocking. This also
 
276
        allows us to specify when we watch for idle timeouts.
 
277
 
 
278
        :return: Did we timeout? (True if we timed out, False if there is data
 
279
            to be read)
 
280
        """
 
281
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
220
282
 
221
283
    def _build_protocol(self):
222
284
        """Identifies the version of the incoming request, and returns an
227
289
 
228
290
        :returns: a SmartServerRequestProtocol.
229
291
        """
 
292
        self._wait_for_bytes_with_timeout(self._client_timeout)
 
293
        if self.finished:
 
294
            # We're stopping, so don't try to do any more work
 
295
            return None
230
296
        bytes = self._get_line()
231
297
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
232
298
        protocol = protocol_factory(
234
300
        protocol.accept_bytes(unused_bytes)
235
301
        return protocol
236
302
 
 
303
    def _wait_on_descriptor(self, fd, timeout_seconds):
 
304
        """select() on a file descriptor, waiting for nonblocking read()
 
305
 
 
306
        This will raise a ConnectionTimeout exception if we do not get a
 
307
        readable handle before timeout_seconds.
 
308
        :return: None
 
309
        """
 
310
        t_end = self._timer() + timeout_seconds
 
311
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
 
312
        rs = xs = None
 
313
        while not rs and not xs and self._timer() < t_end:
 
314
            if self.finished:
 
315
                return
 
316
            try:
 
317
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
 
318
            except (select.error, socket.error) as e:
 
319
                err = getattr(e, 'errno', None)
 
320
                if err is None and getattr(e, 'args', None) is not None:
 
321
                    # select.error doesn't have 'errno', it just has args[0]
 
322
                    err = e.args[0]
 
323
                if err in _bad_file_descriptor:
 
324
                    return # Not a socket indicates read() will fail
 
325
                elif err == errno.EINTR:
 
326
                    # Interrupted, keep looping.
 
327
                    continue
 
328
                raise
 
329
        if rs or xs:
 
330
            return
 
331
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
 
332
                                       % (timeout_seconds,))
 
333
 
237
334
    def _serve_one_request(self, protocol):
238
335
        """Read one request from input, process, send back a response.
239
336
 
240
337
        :param protocol: a SmartServerRequestProtocol.
241
338
        """
 
339
        if protocol is None:
 
340
            return
242
341
        try:
243
342
            self._serve_one_request_unguarded(protocol)
244
343
        except KeyboardInterrupt:
245
344
            raise
246
 
        except Exception, e:
 
345
        except Exception as e:
247
346
            self.terminate_due_to_error()
248
347
 
249
348
    def terminate_due_to_error(self):
260
359
 
261
360
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
262
361
 
263
 
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
362
    def __init__(self, sock, backing_transport, root_client_path='/',
 
363
                 timeout=None):
264
364
        """Constructor.
265
365
 
266
366
        :param sock: the socket the server will read from.  It will be put
267
367
            into blocking mode.
268
368
        """
269
369
        SmartServerStreamMedium.__init__(
270
 
            self, backing_transport, root_client_path=root_client_path)
 
370
            self, backing_transport, root_client_path=root_client_path,
 
371
            timeout=timeout)
271
372
        sock.setblocking(True)
272
373
        self.socket = sock
 
374
        # Get the getpeername now, as we might be closed later when we care.
 
375
        try:
 
376
            self._client_info = sock.getpeername()
 
377
        except socket.error:
 
378
            self._client_info = '<unknown>'
 
379
 
 
380
    def __str__(self):
 
381
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
 
382
 
 
383
    def __repr__(self):
 
384
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
 
385
            self._client_info)
273
386
 
274
387
    def _serve_one_request_unguarded(self, protocol):
275
388
        while protocol.next_read_size():
284
397
 
285
398
        self._push_back(protocol.unused_data)
286
399
 
 
400
    def _disconnect_client(self):
 
401
        """Close the current connection. We stopped due to a timeout/etc."""
 
402
        self.socket.close()
 
403
 
 
404
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
405
        """Wait for more bytes to be read, but timeout if none available.
 
406
 
 
407
        This allows us to detect idle connections, and stop trying to read from
 
408
        them, without setting the socket itself to non-blocking. This also
 
409
        allows us to specify when we watch for idle timeouts.
 
410
 
 
411
        :return: None, this will raise ConnectionTimeout if we time out before
 
412
            data is available.
 
413
        """
 
414
        return self._wait_on_descriptor(self.socket, timeout_seconds)
 
415
 
287
416
    def _read_bytes(self, desired_count):
288
417
        return osutils.read_bytes_from_socket(
289
418
            self.socket, self._report_activity)
306
435
 
307
436
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
308
437
 
309
 
    def __init__(self, in_file, out_file, backing_transport):
 
438
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
310
439
        """Construct new server.
311
440
 
312
441
        :param in_file: Python file from which requests can be read.
313
442
        :param out_file: Python file to write responses.
314
443
        :param backing_transport: Transport for the directory served.
315
444
        """
316
 
        SmartServerStreamMedium.__init__(self, backing_transport)
 
445
        SmartServerStreamMedium.__init__(self, backing_transport,
 
446
            timeout=timeout)
317
447
        if sys.platform == 'win32':
318
448
            # force binary mode for files
319
449
            import msvcrt
324
454
        self._in = in_file
325
455
        self._out = out_file
326
456
 
 
457
    def serve(self):
 
458
        """See SmartServerStreamMedium.serve"""
 
459
        # This is the regular serve, except it adds signal trapping for soft
 
460
        # shutdown.
 
461
        stop_gracefully = self._stop_gracefully
 
462
        signals.register_on_hangup(id(self), stop_gracefully)
 
463
        try:
 
464
            return super(SmartServerPipeStreamMedium, self).serve()
 
465
        finally:
 
466
            signals.unregister_on_hangup(id(self))
 
467
 
327
468
    def _serve_one_request_unguarded(self, protocol):
328
469
        while True:
329
470
            # We need to be careful not to read past the end of the current
342
483
                return
343
484
            protocol.accept_bytes(bytes)
344
485
 
 
486
    def _disconnect_client(self):
 
487
        self._in.close()
 
488
        self._out.flush()
 
489
        self._out.close()
 
490
 
 
491
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
492
        """Wait for more bytes to be read, but timeout if none available.
 
493
 
 
494
        This allows us to detect idle connections, and stop trying to read from
 
495
        them, without setting the socket itself to non-blocking. This also
 
496
        allows us to specify when we watch for idle timeouts.
 
497
 
 
498
        :return: None, this will raise ConnectionTimeout if we time out before
 
499
            data is available.
 
500
        """
 
501
        if (getattr(self._in, 'fileno', None) is None
 
502
            or sys.platform == 'win32'):
 
503
            # You can't select() file descriptors on Windows.
 
504
            return
 
505
        return self._wait_on_descriptor(self._in, timeout_seconds)
 
506
 
345
507
    def _read_bytes(self, desired_count):
346
508
        return self._in.read(desired_count)
347
509
 
491
653
        return self._medium._get_line()
492
654
 
493
655
 
 
656
class _VfsRefuser(object):
 
657
    """An object that refuses all VFS requests.
 
658
 
 
659
    """
 
660
 
 
661
    def __init__(self):
 
662
        client._SmartClient.hooks.install_named_hook(
 
663
            'call', self.check_vfs, 'vfs refuser')
 
664
 
 
665
    def check_vfs(self, params):
 
666
        try:
 
667
            request_method = request.request_handlers.get(params.method)
 
668
        except KeyError:
 
669
            # A method we don't know about doesn't count as a VFS method.
 
670
            return
 
671
        if issubclass(request_method, vfs.VfsRequest):
 
672
            raise HpssVfsRequestNotAllowed(params.method, params.args)
 
673
 
 
674
 
494
675
class _DebugCounter(object):
495
676
    """An object that counts the HPSS calls made to each client medium.
496
677
 
497
 
    When a medium is garbage-collected, or failing that when atexit functions
498
 
    are run, the total number of calls made on that medium are reported via
499
 
    trace.note.
 
678
    When a medium is garbage-collected, or failing that when
 
679
    breezy.global_state exits, the total number of calls made on that medium
 
680
    are reported via trace.note.
500
681
    """
501
682
 
502
683
    def __init__(self):
503
684
        self.counts = weakref.WeakKeyDictionary()
504
685
        client._SmartClient.hooks.install_named_hook(
505
686
            'call', self.increment_call_count, 'hpss call counter')
506
 
        atexit.register(self.flush_all)
 
687
        breezy.get_global_state().cleanups.add_cleanup(self.flush_all)
507
688
 
508
689
    def track(self, medium):
509
690
        """Start tracking calls made to a medium.
543
724
        value['count'] = 0
544
725
        value['vfs_count'] = 0
545
726
        if count != 0:
546
 
            trace.note('HPSS calls: %d (%d vfs) %s',
547
 
                       count, vfs_count, medium_repr)
 
727
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
 
728
                       count, vfs_count, medium_repr))
548
729
 
549
730
    def flush_all(self):
550
731
        for ref in list(self.counts.keys()):
551
732
            self.done(ref)
552
733
 
553
734
_debug_counter = None
 
735
_vfs_refuser = None
554
736
 
555
737
 
556
738
class SmartClientMedium(SmartMedium):
573
755
            if _debug_counter is None:
574
756
                _debug_counter = _DebugCounter()
575
757
            _debug_counter.track(self)
 
758
        if 'hpss_client_no_vfs' in debug.debug_flags:
 
759
            global _vfs_refuser
 
760
            if _vfs_refuser is None:
 
761
                _vfs_refuser = _VfsRefuser()
576
762
 
577
763
    def _is_remote_before(self, version_tuple):
578
764
        """Is it possible the remote side supports RPCs for a given version?
631
817
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
632
818
                client_protocol.query_version()
633
819
                self._done_hello = True
634
 
            except errors.SmartProtocolError, e:
 
820
            except errors.SmartProtocolError as e:
635
821
                # Cache the error, just like we would cache a successful
636
822
                # result.
637
823
                self._protocol_version_error = e
669
855
        """
670
856
        medium_base = urlutils.join(self.base, '/')
671
857
        rel_url = urlutils.relative_url(medium_base, transport.base)
672
 
        return urllib.unquote(rel_url)
 
858
        return urlutils.unquote(rel_url)
673
859
 
674
860
 
675
861
class SmartClientStreamMedium(SmartClientMedium):
710
896
        """
711
897
        return SmartClientStreamMediumRequest(self)
712
898
 
 
899
    def reset(self):
 
900
        """We have been disconnected, reset current state.
 
901
 
 
902
        This resets things like _current_request and connected state.
 
903
        """
 
904
        self.disconnect()
 
905
        self._current_request = None
 
906
 
713
907
 
714
908
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
715
909
    """A client medium using simple pipes.
716
910
 
717
911
    This client does not manage the pipes: it assumes they will always be open.
718
 
 
719
 
    Note that if readable_pipe.read might raise IOError or OSError with errno
720
 
    of EINTR, it must be safe to retry the read.  Plain CPython fileobjects
721
 
    (such as used for sys.stdin) are safe.
722
912
    """
723
913
 
724
914
    def __init__(self, readable_pipe, writeable_pipe, base):
728
918
 
729
919
    def _accept_bytes(self, bytes):
730
920
        """See SmartClientStreamMedium.accept_bytes."""
731
 
        self._writeable_pipe.write(bytes)
 
921
        try:
 
922
            self._writeable_pipe.write(bytes)
 
923
        except IOError as e:
 
924
            if e.errno in (errno.EINVAL, errno.EPIPE):
 
925
                raise errors.ConnectionReset(
 
926
                    "Error trying to write to subprocess", e)
 
927
            raise
732
928
        self._report_activity(len(bytes), 'write')
733
929
 
734
930
    def _flush(self):
735
931
        """See SmartClientStreamMedium._flush()."""
 
932
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
 
933
        #       However, testing shows that even when the child process is
 
934
        #       gone, this doesn't error.
736
935
        self._writeable_pipe.flush()
737
936
 
738
937
    def _read_bytes(self, count):
739
938
        """See SmartClientStreamMedium._read_bytes."""
740
 
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
 
939
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
940
        bytes = self._readable_pipe.read(bytes_to_read)
741
941
        self._report_activity(len(bytes), 'read')
742
942
        return bytes
743
943
 
744
944
 
 
945
class SSHParams(object):
 
946
    """A set of parameters for starting a remote bzr via SSH."""
 
947
 
 
948
    def __init__(self, host, port=None, username=None, password=None,
 
949
            bzr_remote_path='bzr'):
 
950
        self.host = host
 
951
        self.port = port
 
952
        self.username = username
 
953
        self.password = password
 
954
        self.bzr_remote_path = bzr_remote_path
 
955
 
 
956
 
745
957
class SmartSSHClientMedium(SmartClientStreamMedium):
746
 
    """A client medium using SSH."""
747
 
 
748
 
    def __init__(self, host, port=None, username=None, password=None,
749
 
            base=None, vendor=None, bzr_remote_path=None):
 
958
    """A client medium using SSH.
 
959
 
 
960
    It delegates IO to a SmartSimplePipesClientMedium or
 
961
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
 
962
    """
 
963
 
 
964
    def __init__(self, base, ssh_params, vendor=None):
750
965
        """Creates a client that will connect on the first use.
751
966
 
 
967
        :param ssh_params: A SSHParams instance.
752
968
        :param vendor: An optional override for the ssh vendor to use. See
753
 
            bzrlib.transport.ssh for details on ssh vendors.
 
969
            breezy.transport.ssh for details on ssh vendors.
754
970
        """
755
 
        self._connected = False
756
 
        self._host = host
757
 
        self._password = password
758
 
        self._port = port
759
 
        self._username = username
 
971
        self._real_medium = None
 
972
        self._ssh_params = ssh_params
760
973
        # for the benefit of progress making a short description of this
761
974
        # transport
762
975
        self._scheme = 'bzr+ssh'
764
977
        # _DebugCounter so we have to store all the values used in our repr
765
978
        # method before calling the super init.
766
979
        SmartClientStreamMedium.__init__(self, base)
767
 
        self._read_from = None
 
980
        self._vendor = vendor
768
981
        self._ssh_connection = None
769
 
        self._vendor = vendor
770
 
        self._write_to = None
771
 
        self._bzr_remote_path = bzr_remote_path
772
982
 
773
983
    def __repr__(self):
774
 
        if self._port is None:
 
984
        if self._ssh_params.port is None:
775
985
            maybe_port = ''
776
986
        else:
777
 
            maybe_port = ':%s' % self._port
778
 
        return "%s(%s://%s@%s%s/)" % (
 
987
            maybe_port = ':%s' % self._ssh_params.port
 
988
        if self._ssh_params.username is None:
 
989
            maybe_user = ''
 
990
        else:
 
991
            maybe_user = '%s@' % self._ssh_params.username
 
992
        return "%s(%s://%s%s%s/)" % (
779
993
            self.__class__.__name__,
780
994
            self._scheme,
781
 
            self._username,
782
 
            self._host,
 
995
            maybe_user,
 
996
            self._ssh_params.host,
783
997
            maybe_port)
784
998
 
785
999
    def _accept_bytes(self, bytes):
786
1000
        """See SmartClientStreamMedium.accept_bytes."""
787
1001
        self._ensure_connection()
788
 
        self._write_to.write(bytes)
789
 
        self._report_activity(len(bytes), 'write')
 
1002
        self._real_medium.accept_bytes(bytes)
790
1003
 
791
1004
    def disconnect(self):
792
1005
        """See SmartClientMedium.disconnect()."""
793
 
        if not self._connected:
794
 
            return
795
 
        self._read_from.close()
796
 
        self._write_to.close()
797
 
        self._ssh_connection.close()
798
 
        self._connected = False
 
1006
        if self._real_medium is not None:
 
1007
            self._real_medium.disconnect()
 
1008
            self._real_medium = None
 
1009
        if self._ssh_connection is not None:
 
1010
            self._ssh_connection.close()
 
1011
            self._ssh_connection = None
799
1012
 
800
1013
    def _ensure_connection(self):
801
1014
        """Connect this medium if not already connected."""
802
 
        if self._connected:
 
1015
        if self._real_medium is not None:
803
1016
            return
804
1017
        if self._vendor is None:
805
1018
            vendor = ssh._get_ssh_vendor()
806
1019
        else:
807
1020
            vendor = self._vendor
808
 
        self._ssh_connection = vendor.connect_ssh(self._username,
809
 
                self._password, self._host, self._port,
810
 
                command=[self._bzr_remote_path, 'serve', '--inet',
 
1021
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
 
1022
                self._ssh_params.password, self._ssh_params.host,
 
1023
                self._ssh_params.port,
 
1024
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
811
1025
                         '--directory=/', '--allow-writes'])
812
 
        self._read_from, self._write_to = \
813
 
            self._ssh_connection.get_filelike_channels()
814
 
        self._connected = True
 
1026
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
 
1027
        if io_kind == 'socket':
 
1028
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
 
1029
                self.base, io_object)
 
1030
        elif io_kind == 'pipes':
 
1031
            read_from, write_to = io_object
 
1032
            self._real_medium = SmartSimplePipesClientMedium(
 
1033
                read_from, write_to, self.base)
 
1034
        else:
 
1035
            raise AssertionError(
 
1036
                "Unexpected io_kind %r from %r"
 
1037
                % (io_kind, self._ssh_connection))
 
1038
        for hook in transport.Transport.hooks["post_connect"]:
 
1039
            hook(self)
815
1040
 
816
1041
    def _flush(self):
817
1042
        """See SmartClientStreamMedium._flush()."""
818
 
        self._write_to.flush()
 
1043
        self._real_medium._flush()
819
1044
 
820
1045
    def _read_bytes(self, count):
821
1046
        """See SmartClientStreamMedium.read_bytes."""
822
 
        if not self._connected:
 
1047
        if self._real_medium is None:
823
1048
            raise errors.MediumNotConnected(self)
824
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
825
 
        bytes = self._read_from.read(bytes_to_read)
826
 
        self._report_activity(len(bytes), 'read')
827
 
        return bytes
 
1049
        return self._real_medium.read_bytes(count)
828
1050
 
829
1051
 
830
1052
# Port 4155 is the default port for bzr://, registered with IANA.
832
1054
BZR_DEFAULT_PORT = 4155
833
1055
 
834
1056
 
835
 
class SmartTCPClientMedium(SmartClientStreamMedium):
836
 
    """A client medium using TCP."""
837
 
 
838
 
    def __init__(self, host, port, base):
839
 
        """Creates a client that will connect on the first use."""
 
1057
class SmartClientSocketMedium(SmartClientStreamMedium):
 
1058
    """A client medium using a socket.
 
1059
 
 
1060
    This class isn't usable directly.  Use one of its subclasses instead.
 
1061
    """
 
1062
 
 
1063
    def __init__(self, base):
840
1064
        SmartClientStreamMedium.__init__(self, base)
 
1065
        self._socket = None
841
1066
        self._connected = False
842
 
        self._host = host
843
 
        self._port = port
844
 
        self._socket = None
845
1067
 
846
1068
    def _accept_bytes(self, bytes):
847
1069
        """See SmartClientMedium.accept_bytes."""
848
1070
        self._ensure_connection()
849
1071
        osutils.send_all(self._socket, bytes, self._report_activity)
850
1072
 
 
1073
    def _ensure_connection(self):
 
1074
        """Connect this medium if not already connected."""
 
1075
        raise NotImplementedError(self._ensure_connection)
 
1076
 
 
1077
    def _flush(self):
 
1078
        """See SmartClientStreamMedium._flush().
 
1079
 
 
1080
        For sockets we do no flushing. For TCP sockets we may want to turn off
 
1081
        TCP_NODELAY and add a means to do a flush, but that can be done in the
 
1082
        future.
 
1083
        """
 
1084
 
 
1085
    def _read_bytes(self, count):
 
1086
        """See SmartClientMedium.read_bytes."""
 
1087
        if not self._connected:
 
1088
            raise errors.MediumNotConnected(self)
 
1089
        return osutils.read_bytes_from_socket(
 
1090
            self._socket, self._report_activity)
 
1091
 
851
1092
    def disconnect(self):
852
1093
        """See SmartClientMedium.disconnect()."""
853
1094
        if not self._connected:
856
1097
        self._socket = None
857
1098
        self._connected = False
858
1099
 
 
1100
 
 
1101
class SmartTCPClientMedium(SmartClientSocketMedium):
 
1102
    """A client medium that creates a TCP connection."""
 
1103
 
 
1104
    def __init__(self, host, port, base):
 
1105
        """Creates a client that will connect on the first use."""
 
1106
        SmartClientSocketMedium.__init__(self, base)
 
1107
        self._host = host
 
1108
        self._port = port
 
1109
 
859
1110
    def _ensure_connection(self):
860
1111
        """Connect this medium if not already connected."""
861
1112
        if self._connected:
867
1118
        try:
868
1119
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
869
1120
                socket.SOCK_STREAM, 0, 0)
870
 
        except socket.gaierror, (err_num, err_msg):
 
1121
        except socket.gaierror as xxx_todo_changeme:
 
1122
            (err_num, err_msg) = xxx_todo_changeme.args
871
1123
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
872
1124
                    (self._host, port, err_msg))
873
1125
        # Initialize err in case there are no addresses returned:
878
1130
                self._socket.setsockopt(socket.IPPROTO_TCP,
879
1131
                                        socket.TCP_NODELAY, 1)
880
1132
                self._socket.connect(sockaddr)
881
 
            except socket.error, err:
 
1133
            except socket.error as err:
882
1134
                if self._socket is not None:
883
1135
                    self._socket.close()
884
1136
                self._socket = None
887
1139
        if self._socket is None:
888
1140
            # socket errors either have a (string) or (errno, string) as their
889
1141
            # args.
890
 
            if type(err.args) is str:
 
1142
            if isinstance(err.args, str):
891
1143
                err_msg = err.args
892
1144
            else:
893
1145
                err_msg = err.args[1]
894
1146
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
895
1147
                    (self._host, port, err_msg))
896
1148
        self._connected = True
897
 
 
898
 
    def _flush(self):
899
 
        """See SmartClientStreamMedium._flush().
900
 
 
901
 
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
902
 
        add a means to do a flush, but that can be done in the future.
903
 
        """
904
 
 
905
 
    def _read_bytes(self, count):
906
 
        """See SmartClientMedium.read_bytes."""
907
 
        if not self._connected:
908
 
            raise errors.MediumNotConnected(self)
909
 
        return osutils.read_bytes_from_socket(
910
 
            self._socket, self._report_activity)
 
1149
        for hook in transport.Transport.hooks["post_connect"]:
 
1150
            hook(self)
 
1151
 
 
1152
 
 
1153
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
 
1154
    """A client medium for an already connected socket.
 
1155
    
 
1156
    Note that this class will assume it "owns" the socket, so it will close it
 
1157
    when its disconnect method is called.
 
1158
    """
 
1159
 
 
1160
    def __init__(self, base, sock):
 
1161
        SmartClientSocketMedium.__init__(self, base)
 
1162
        self._socket = sock
 
1163
        self._connected = True
 
1164
 
 
1165
    def _ensure_connection(self):
 
1166
        # Already connected, by definition!  So nothing to do.
 
1167
        pass
911
1168
 
912
1169
 
913
1170
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
948
1205
        This invokes self._medium._flush to ensure all bytes are transmitted.
949
1206
        """
950
1207
        self._medium._flush()
951
 
 
952