/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/transport/smart.py

  • Committer: John Arbash Meinel
  • Date: 2006-09-16 02:26:44 UTC
  • mfrom: (2017 +trunk)
  • mto: This revision was merged to the branch mainline in revision 2020.
  • Revision ID: john@arbash-meinel.com-20060916022644-e19857e642b00a9e
[merge] bzr.dev 2017

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""Smart-server protocol, client and server.
 
18
 
 
19
Requests are sent as a command and list of arguments, followed by optional
 
20
bulk body data.  Responses are similarly a response and list of arguments,
 
21
followed by bulk body data. ::
 
22
 
 
23
  SEP := '\001'
 
24
    Fields are separated by Ctrl-A.
 
25
  BULK_DATA := CHUNK+ TRAILER
 
26
    Chunks can be repeated as many times as necessary.
 
27
  CHUNK := CHUNK_LEN CHUNK_BODY
 
28
  CHUNK_LEN := DIGIT+ NEWLINE
 
29
    Gives the number of bytes in the following chunk.
 
30
  CHUNK_BODY := BYTE[chunk_len]
 
31
  TRAILER := SUCCESS_TRAILER | ERROR_TRAILER
 
32
  SUCCESS_TRAILER := 'done' NEWLINE
 
33
  ERROR_TRAILER := 
 
34
 
 
35
Paths are passed across the network.  The client needs to see a namespace that
 
36
includes any repository that might need to be referenced, and the client needs
 
37
to know about a root directory beyond which it cannot ascend.
 
38
 
 
39
Servers run over ssh will typically want to be able to access any path the user 
 
40
can access.  Public servers on the other hand (which might be over http, ssh
 
41
or tcp) will typically want to restrict access to only a particular directory 
 
42
and its children, so will want to do a software virtual root at that level.
 
43
In other words they'll want to rewrite incoming paths to be under that level
 
44
(and prevent escaping using ../ tricks.)
 
45
 
 
46
URLs that include ~ should probably be passed across to the server verbatim
 
47
and the server can expand them.  This will proably not be meaningful when 
 
48
limited to a directory?
 
49
"""
 
50
 
 
51
 
 
52
 
 
53
# TODO: A plain integer from query_version is too simple; should give some
 
54
# capabilities too?
 
55
 
 
56
# TODO: Server should probably catch exceptions within itself and send them
 
57
# back across the network.  (But shouldn't catch KeyboardInterrupt etc)
 
58
# Also needs to somehow report protocol errors like bad requests.  Need to
 
59
# consider how we'll handle error reporting, e.g. if we get halfway through a
 
60
# bulk transfer and then something goes wrong.
 
61
 
 
62
# TODO: Standard marker at start of request/response lines?
 
63
 
 
64
# TODO: Make each request and response self-validatable, e.g. with checksums.
 
65
#
 
66
# TODO: get/put objects could be changed to gradually read back the data as it
 
67
# comes across the network
 
68
#
 
69
# TODO: What should the server do if it hits an error and has to terminate?
 
70
#
 
71
# TODO: is it useful to allow multiple chunks in the bulk data?
 
72
#
 
73
# TODO: If we get an exception during transmission of bulk data we can't just
 
74
# emit the exception because it won't be seen.
 
75
#   John proposes:  I think it would be worthwhile to have a header on each
 
76
#   chunk, that indicates it is another chunk. Then you can send an 'error'
 
77
#   chunk as long as you finish the previous chunk.
 
78
#
 
79
# TODO: Clone method on Transport; should work up towards parent directory;
 
80
# unclear how this should be stored or communicated to the server... maybe
 
81
# just pass it on all relevant requests?
 
82
#
 
83
# TODO: Better name than clone() for changing between directories.  How about
 
84
# open_dir or change_dir or chdir?
 
85
#
 
86
# TODO: Is it really good to have the notion of current directory within the
 
87
# connection?  Perhaps all Transports should factor out a common connection
 
88
# from the thing that has the directory context?
 
89
#
 
90
# TODO: Pull more things common to sftp and ssh to a higher level.
 
91
#
 
92
# TODO: The server that manages a connection should be quite small and retain
 
93
# minimum state because each of the requests are supposed to be stateless.
 
94
# Then we can write another implementation that maps to http.
 
95
#
 
96
# TODO: What to do when a client connection is garbage collected?  Maybe just
 
97
# abruptly drop the connection?
 
98
#
 
99
# TODO: Server in some cases will need to restrict access to files outside of
 
100
# a particular root directory.  LocalTransport doesn't do anything to stop you
 
101
# ascending above the base directory, so we need to prevent paths
 
102
# containing '..' in either the server or transport layers.  (Also need to
 
103
# consider what happens if someone creates a symlink pointing outside the 
 
104
# directory tree...)
 
105
#
 
106
# TODO: Server should rebase absolute paths coming across the network to put
 
107
# them under the virtual root, if one is in use.  LocalTransport currently
 
108
# doesn't do that; if you give it an absolute path it just uses it.
 
109
 
110
# XXX: Arguments can't contain newlines or ascii; possibly we should e.g.
 
111
# urlescape them instead.  Indeed possibly this should just literally be
 
112
# http-over-ssh.
 
113
#
 
114
# FIXME: This transport, with several others, has imperfect handling of paths
 
115
# within urls.  It'd probably be better for ".." from a root to raise an error
 
116
# rather than return the same directory as we do at present.
 
117
#
 
118
# TODO: Rather than working at the Transport layer we want a Branch,
 
119
# Repository or BzrDir objects that talk to a server.
 
120
#
 
121
# TODO: Probably want some way for server commands to gradually produce body
 
122
# data rather than passing it as a string; they could perhaps pass an
 
123
# iterator-like callback that will gradually yield data; it probably needs a
 
124
# close() method that will always be closed to do any necessary cleanup.
 
125
#
 
126
# TODO: Split the actual smart server from the ssh encoding of it.
 
127
#
 
128
# TODO: Perhaps support file-level readwrite operations over the transport
 
129
# too.
 
130
#
 
131
# TODO: SmartBzrDir class, proxying all Branch etc methods across to another
 
132
# branch doing file-level operations.
 
133
#
 
134
# TODO: jam 20060915 _decode_tuple is acting directly on input over
 
135
#       the socket, and it assumes everything is UTF8 sections separated
 
136
#       by \001. Which means a request like '\002' Will abort the connection
 
137
#       because of a UnicodeDecodeError. It does look like invalid data will
 
138
#       kill the SmartStreamServer, but only with an abort + exception, and 
 
139
#       the overall server shouldn't die.
 
140
 
 
141
from cStringIO import StringIO
 
142
import errno
 
143
import os
 
144
import socket
 
145
import sys
 
146
import tempfile
 
147
import threading
 
148
import urllib
 
149
import urlparse
 
150
 
 
151
from bzrlib import (
 
152
    bzrdir,
 
153
    errors,
 
154
    revision,
 
155
    transport,
 
156
    trace,
 
157
    urlutils,
 
158
    )
 
159
from bzrlib.bundle.serializer import write_bundle
 
160
from bzrlib.trace import mutter
 
161
from bzrlib.transport import local
 
162
 
 
163
# must do this otherwise urllib can't parse the urls properly :(
 
164
for scheme in ['ssh', 'bzr', 'bzr+loopback', 'bzr+ssh']:
 
165
    transport.register_urlparse_netloc_protocol(scheme)
 
166
del scheme
 
167
 
 
168
 
 
169
def _recv_tuple(from_file):
 
170
    req_line = from_file.readline()
 
171
    return _decode_tuple(req_line)
 
172
 
 
173
 
 
174
def _decode_tuple(req_line):
 
175
    if req_line == None or req_line == '':
 
176
        return None
 
177
    if req_line[-1] != '\n':
 
178
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
179
    return tuple((a.decode('utf-8') for a in req_line[:-1].split('\x01')))
 
180
 
 
181
 
 
182
def _send_tuple(to_file, args):
 
183
    # XXX: this will be inefficient.  Just ask Robert.
 
184
    to_file.write('\x01'.join((a.encode('utf-8') for a in args)) + '\n')
 
185
    to_file.flush()
 
186
 
 
187
 
 
188
class SmartProtocolBase(object):
 
189
    """Methods common to client and server"""
 
190
 
 
191
    def _send_bulk_data(self, body):
 
192
        """Send chunked body data"""
 
193
        assert isinstance(body, str)
 
194
        self._out.write('%d\n' % len(body))
 
195
        self._out.write(body)
 
196
        self._out.write('done\n')
 
197
        self._out.flush()
 
198
 
 
199
    # TODO: this only actually accomodates a single block; possibly should support
 
200
    # multiple chunks?
 
201
    def _recv_bulk(self):
 
202
        chunk_len = self._in.readline()
 
203
        try:
 
204
            chunk_len = int(chunk_len)
 
205
        except ValueError:
 
206
            raise errors.SmartProtocolError("bad chunk length line %r" % chunk_len)
 
207
        bulk = self._in.read(chunk_len)
 
208
        if len(bulk) != chunk_len:
 
209
            raise errors.SmartProtocolError("short read fetching bulk data chunk")
 
210
        self._recv_trailer()
 
211
        return bulk
 
212
 
 
213
    def _recv_tuple(self):
 
214
        return _recv_tuple(self._in)
 
215
 
 
216
    def _recv_trailer(self):
 
217
        resp = self._recv_tuple()
 
218
        if resp == ('done', ):
 
219
            return
 
220
        else:
 
221
            self._translate_error(resp)
 
222
 
 
223
 
 
224
class SmartStreamServer(SmartProtocolBase):
 
225
    """Handles smart commands coming over a stream.
 
226
 
 
227
    The stream may be a pipe connected to sshd, or a tcp socket, or an
 
228
    in-process fifo for testing.
 
229
 
 
230
    One instance is created for each connected client; it can serve multiple
 
231
    requests in the lifetime of the connection.
 
232
 
 
233
    The server passes requests through to an underlying backing transport, 
 
234
    which will typically be a LocalTransport looking at the server's filesystem.
 
235
    """
 
236
 
 
237
    def __init__(self, in_file, out_file, backing_transport):
 
238
        """Construct new server.
 
239
 
 
240
        :param in_file: Python file from which requests can be read.
 
241
        :param out_file: Python file to write responses.
 
242
        :param backing_transport: Transport for the directory served.
 
243
        """
 
244
        self._in = in_file
 
245
        self._out = out_file
 
246
        self.smart_server = SmartServer(backing_transport)
 
247
        # server can call back to us to get bulk data - this is not really
 
248
        # ideal, they should get it per request instead
 
249
        self.smart_server._recv_body = self._recv_bulk
 
250
 
 
251
    def _recv_tuple(self):
 
252
        """Read a request from the client and return as a tuple.
 
253
        
 
254
        Returns None at end of file (if the client closed the connection.)
 
255
        """
 
256
        return _recv_tuple(self._in)
 
257
 
 
258
    def _send_tuple(self, args):
 
259
        """Send response header"""
 
260
        return _send_tuple(self._out, args)
 
261
 
 
262
    def _send_error_and_disconnect(self, exception):
 
263
        self._send_tuple(('error', str(exception)))
 
264
        self._out.flush()
 
265
        ## self._out.close()
 
266
        ## self._in.close()
 
267
 
 
268
    def _serve_one_request(self):
 
269
        """Read one request from input, process, send back a response.
 
270
        
 
271
        :return: False if the server should terminate, otherwise None.
 
272
        """
 
273
        req_args = self._recv_tuple()
 
274
        if req_args == None:
 
275
            # client closed connection
 
276
            return False  # shutdown server
 
277
        try:
 
278
            response = self.smart_server.dispatch_command(req_args[0], req_args[1:])
 
279
            self._send_tuple(response.args)
 
280
            if response.body is not None:
 
281
                self._send_bulk_data(response.body)
 
282
        except KeyboardInterrupt:
 
283
            raise
 
284
        except Exception, e:
 
285
            # everything else: pass to client, flush, and quit
 
286
            self._send_error_and_disconnect(e)
 
287
            return False
 
288
 
 
289
    def serve(self):
 
290
        """Serve requests until the client disconnects."""
 
291
        # Keep a reference to stderr because the sys module's globals get set to
 
292
        # None during interpreter shutdown.
 
293
        from sys import stderr
 
294
        try:
 
295
            while self._serve_one_request() != False:
 
296
                pass
 
297
        except Exception, e:
 
298
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
299
            raise
 
300
 
 
301
 
 
302
class SmartServerResponse(object):
 
303
    """Response generated by SmartServer."""
 
304
 
 
305
    def __init__(self, args, body=None):
 
306
        self.args = args
 
307
        self.body = body
 
308
 
 
309
 
 
310
class SmartServer(object):
 
311
    """Protocol logic for smart server.
 
312
    
 
313
    This doesn't handle serialization at all, it just processes requests and
 
314
    creates responses.
 
315
    """
 
316
 
 
317
    # TODO: Better way of representing the body for commands that take it,
 
318
    # and allow it to be streamed into the server.
 
319
    
 
320
    def __init__(self, backing_transport):
 
321
        self._backing_transport = backing_transport
 
322
        
 
323
    def do_hello(self):
 
324
        """Answer a version request with my version."""
 
325
        return SmartServerResponse(('ok', '1'))
 
326
 
 
327
    def do_has(self, relpath):
 
328
        r = self._backing_transport.has(relpath) and 'yes' or 'no'
 
329
        return SmartServerResponse((r,))
 
330
 
 
331
    def do_get(self, relpath):
 
332
        backing_bytes = self._backing_transport.get_bytes(relpath)
 
333
        return SmartServerResponse(('ok',), backing_bytes)
 
334
 
 
335
    def _deserialise_optional_mode(self, mode):
 
336
        if mode == '':
 
337
            return None
 
338
        else:
 
339
            return int(mode)
 
340
 
 
341
    def do_append(self, relpath, mode):
 
342
        old_length = self._backing_transport.append_bytes(
 
343
            relpath, self._recv_body(), self._deserialise_optional_mode(mode))
 
344
        return SmartServerResponse(('appended', '%d' % old_length))
 
345
 
 
346
    def do_delete(self, relpath):
 
347
        self._backing_transport.delete(relpath)
 
348
 
 
349
    def do_iter_files_recursive(self, abspath):
 
350
        # XXX: the path handling needs some thought.
 
351
        #relpath = self._backing_transport.relpath(abspath)
 
352
        transport = self._backing_transport.clone(abspath)
 
353
        filenames = transport.iter_files_recursive()
 
354
        return SmartServerResponse(('names',) + tuple(filenames))
 
355
 
 
356
    def do_list_dir(self, relpath):
 
357
        filenames = self._backing_transport.list_dir(relpath)
 
358
        return SmartServerResponse(('names',) + tuple(filenames))
 
359
 
 
360
    def do_mkdir(self, relpath, mode):
 
361
        self._backing_transport.mkdir(relpath,
 
362
                                      self._deserialise_optional_mode(mode))
 
363
 
 
364
    def do_move(self, rel_from, rel_to):
 
365
        self._backing_transport.move(rel_from, rel_to)
 
366
 
 
367
    def do_put(self, relpath, mode):
 
368
        self._backing_transport.put_bytes(relpath,
 
369
                self._recv_body(),
 
370
                self._deserialise_optional_mode(mode))
 
371
 
 
372
    @staticmethod
 
373
    def _deserialise_offsets(text):
 
374
        offsets = []
 
375
        for line in text.split('\n'):
 
376
            if not line:
 
377
                continue
 
378
            start, length = line.split(',')
 
379
            offsets.append((int(start), int(length)))
 
380
        return offsets
 
381
 
 
382
    def do_readv(self, relpath):
 
383
        offsets = self._deserialise_offsets(self._recv_body())
 
384
        backing_bytes = ''.join(bytes for offset, bytes in
 
385
                             self._backing_transport.readv(relpath, offsets))
 
386
        return SmartServerResponse(('readv',), backing_bytes)
 
387
        
 
388
    def do_rename(self, rel_from, rel_to):
 
389
        self._backing_transport.rename(rel_from, rel_to)
 
390
 
 
391
    def do_rmdir(self, relpath):
 
392
        self._backing_transport.rmdir(relpath)
 
393
 
 
394
    def do_stat(self, relpath):
 
395
        stat = self._backing_transport.stat(relpath)
 
396
        return SmartServerResponse(('stat', str(stat.st_size), oct(stat.st_mode)))
 
397
        
 
398
    def do_get_bundle(self, path, revision_id):
 
399
        # open transport relative to our base
 
400
        t = self._backing_transport.clone(path)
 
401
        control, extra_path = bzrdir.BzrDir.open_containing_from_transport(t)
 
402
        repo = control.open_repository()
 
403
        tmpf = tempfile.TemporaryFile()
 
404
        base_revision = revision.NULL_REVISION
 
405
        write_bundle(repo, revision_id, base_revision, tmpf)
 
406
        tmpf.seek(0)
 
407
        return SmartServerResponse((), tmpf.read())
 
408
 
 
409
    def dispatch_command(self, cmd, args):
 
410
        func = getattr(self, 'do_' + cmd, None)
 
411
        if func is None:
 
412
            raise errors.SmartProtocolError("bad request %r" % (cmd,))
 
413
        try:
 
414
            result = func(*args)
 
415
            if result is None: 
 
416
                result = SmartServerResponse(('ok',))
 
417
            return result
 
418
        except errors.NoSuchFile, e:
 
419
            return SmartServerResponse(('NoSuchFile', e.path))
 
420
        except errors.FileExists, e:
 
421
            return SmartServerResponse(('FileExists', e.path))
 
422
        except errors.DirectoryNotEmpty, e:
 
423
            return SmartServerResponse(('DirectoryNotEmpty', e.path))
 
424
        except errors.ShortReadvError, e:
 
425
            return SmartServerResponse(('ShortReadvError',
 
426
                e.path, str(e.offset), str(e.length), str(e.actual)))
 
427
        except UnicodeError, e:
 
428
            # If it is a DecodeError, than most likely we are starting
 
429
            # with a plain string
 
430
            str_or_unicode = e.object
 
431
            if isinstance(str_or_unicode, unicode):
 
432
                val = u'u:' + str_or_unicode
 
433
            else:
 
434
                val = u's:' + str_or_unicode.encode('base64')
 
435
            # This handles UnicodeEncodeError or UnicodeDecodeError
 
436
            return SmartServerResponse((e.__class__.__name__,
 
437
                    e.encoding, val, str(e.start), str(e.end), e.reason))
 
438
 
 
439
 
 
440
class SmartTCPServer(object):
 
441
    """Listens on a TCP socket and accepts connections from smart clients"""
 
442
 
 
443
    def __init__(self, backing_transport=None, host='127.0.0.1', port=0):
 
444
        """Construct a new server.
 
445
 
 
446
        To actually start it running, call either start_background_thread or
 
447
        serve.
 
448
 
 
449
        :param host: Name of the interface to listen on.
 
450
        :param port: TCP port to listen on, or 0 to allocate a transient port.
 
451
        """
 
452
        if backing_transport is None:
 
453
            backing_transport = memory.MemoryTransport()
 
454
        self._server_socket = socket.socket()
 
455
        self._server_socket.bind((host, port))
 
456
        self.port = self._server_socket.getsockname()[1]
 
457
        self._server_socket.listen(1)
 
458
        self._server_socket.settimeout(1)
 
459
        self.backing_transport = backing_transport
 
460
 
 
461
    def serve(self):
 
462
        # let connections timeout so that we get a chance to terminate
 
463
        # Keep a reference to the exceptions we want to catch because the socket
 
464
        # module's globals get set to None during interpreter shutdown.
 
465
        from socket import timeout as socket_timeout
 
466
        from socket import error as socket_error
 
467
        self._should_terminate = False
 
468
        while not self._should_terminate:
 
469
            try:
 
470
                self.accept_and_serve()
 
471
            except socket_timeout:
 
472
                # just check if we're asked to stop
 
473
                pass
 
474
            except socket_error, e:
 
475
                trace.warning("client disconnected: %s", e)
 
476
                pass
 
477
 
 
478
    def get_url(self):
 
479
        """Return the url of the server"""
 
480
        return "bzr://%s:%d/" % self._server_socket.getsockname()
 
481
 
 
482
    def accept_and_serve(self):
 
483
        conn, client_addr = self._server_socket.accept()
 
484
        conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
485
        from_client = conn.makefile('r')
 
486
        to_client = conn.makefile('w')
 
487
        handler = SmartStreamServer(from_client, to_client,
 
488
                self.backing_transport)
 
489
        connection_thread = threading.Thread(None, handler.serve, name='smart-server-child')
 
490
        connection_thread.setDaemon(True)
 
491
        connection_thread.start()
 
492
 
 
493
    def start_background_thread(self):
 
494
        self._server_thread = threading.Thread(None,
 
495
                self.serve,
 
496
                name='server-' + self.get_url())
 
497
        self._server_thread.setDaemon(True)
 
498
        self._server_thread.start()
 
499
 
 
500
    def stop_background_thread(self):
 
501
        self._should_terminate = True
 
502
        # self._server_socket.close()
 
503
        # we used to join the thread, but it's not really necessary; it will
 
504
        # terminate in time
 
505
        ## self._server_thread.join()
 
506
 
 
507
 
 
508
class SmartTCPServer_for_testing(SmartTCPServer):
 
509
    """Server suitable for use by transport tests.
 
510
    
 
511
    This server is backed by the process's cwd.
 
512
    """
 
513
 
 
514
    def __init__(self):
 
515
        self._homedir = os.getcwd()
 
516
        # The server is set up by default like for ssh access: the client
 
517
        # passes filesystem-absolute paths; therefore the server must look
 
518
        # them up relative to the root directory.  it might be better to act
 
519
        # a public server and have the server rewrite paths into the test
 
520
        # directory.
 
521
        SmartTCPServer.__init__(self, transport.get_transport("file:///"))
 
522
        
 
523
    def setUp(self):
 
524
        """Set up server for testing"""
 
525
        self.start_background_thread()
 
526
 
 
527
    def tearDown(self):
 
528
        self.stop_background_thread()
 
529
 
 
530
    def get_url(self):
 
531
        """Return the url of the server"""
 
532
        host, port = self._server_socket.getsockname()
 
533
        # XXX: I think this is likely to break on windows -- self._homedir will
 
534
        # have backslashes (and maybe a drive letter?).
 
535
        #  -- Andrew Bennetts, 2006-08-29
 
536
        return "bzr://%s:%d%s" % (host, port, urlutils.escape(self._homedir))
 
537
 
 
538
    def get_bogus_url(self):
 
539
        """Return a URL which will fail to connect"""
 
540
        return 'bzr://127.0.0.1:1/'
 
541
 
 
542
 
 
543
class SmartStat(object):
 
544
 
 
545
    def __init__(self, size, mode):
 
546
        self.st_size = size
 
547
        self.st_mode = mode
 
548
 
 
549
 
 
550
class SmartTransport(transport.Transport):
 
551
    """Connection to a smart server.
 
552
 
 
553
    The connection holds references to pipes that can be used to send requests
 
554
    to the server.
 
555
 
 
556
    The connection has a notion of the current directory to which it's
 
557
    connected; this is incorporated in filenames passed to the server.
 
558
    
 
559
    This supports some higher-level RPC operations and can also be treated 
 
560
    like a Transport to do file-like operations.
 
561
 
 
562
    The connection can be made over a tcp socket, or (in future) an ssh pipe
 
563
    or a series of http requests.  There are concrete subclasses for each
 
564
    type: SmartTCPTransport, etc.
 
565
    """
 
566
 
 
567
    def __init__(self, url, clone_from=None, client=None):
 
568
        """Constructor.
 
569
 
 
570
        :param client: ignored when clone_from is not None.
 
571
        """
 
572
        ### Technically super() here is faulty because Transport's __init__
 
573
        ### fails to take 2 parameters, and if super were to choose a silly
 
574
        ### initialisation order things would blow up. 
 
575
        if not url.endswith('/'):
 
576
            url += '/'
 
577
        super(SmartTransport, self).__init__(url)
 
578
        self._scheme, self._username, self._password, self._host, self._port, self._path = \
 
579
                transport.split_url(url)
 
580
        if clone_from is None:
 
581
            if client is None:
 
582
                self._client = SmartStreamClient(self._connect_to_server)
 
583
            else:
 
584
                self._client = client
 
585
        else:
 
586
            # credentials may be stripped from the base in some circumstances
 
587
            # as yet to be clearly defined or documented, so copy them.
 
588
            self._username = clone_from._username
 
589
            # reuse same connection
 
590
            self._client = clone_from._client
 
591
 
 
592
    def abspath(self, relpath):
 
593
        """Return the full url to the given relative path.
 
594
        
 
595
        @param relpath: the relative path or path components
 
596
        @type relpath: str or list
 
597
        """
 
598
        return self._unparse_url(self._remote_path(relpath))
 
599
    
 
600
    def clone(self, relative_url):
 
601
        """Make a new SmartTransport related to me, sharing the same connection.
 
602
 
 
603
        This essentially opens a handle on a different remote directory.
 
604
        """
 
605
        if relative_url is None:
 
606
            return self.__class__(self.base, self)
 
607
        else:
 
608
            return self.__class__(self.abspath(relative_url), self)
 
609
 
 
610
    def is_readonly(self):
 
611
        """Smart server transport can do read/write file operations."""
 
612
        return False
 
613
                                                   
 
614
    def get_smart_client(self):
 
615
        return self._client
 
616
                                                   
 
617
    def _unparse_url(self, path):
 
618
        """Return URL for a path.
 
619
 
 
620
        :see: SFTPUrlHandling._unparse_url
 
621
        """
 
622
        # TODO: Eventually it should be possible to unify this with
 
623
        # SFTPUrlHandling._unparse_url?
 
624
        if path == '':
 
625
            path = '/'
 
626
        path = urllib.quote(path)
 
627
        netloc = urllib.quote(self._host)
 
628
        if self._username is not None:
 
629
            netloc = '%s@%s' % (urllib.quote(self._username), netloc)
 
630
        if self._port is not None:
 
631
            netloc = '%s:%d' % (netloc, self._port)
 
632
        return urlparse.urlunparse((self._scheme, netloc, path, '', '', ''))
 
633
 
 
634
    def _remote_path(self, relpath):
 
635
        """Returns the Unicode version of the absolute path for relpath."""
 
636
        return self._combine_paths(self._path, relpath)
 
637
 
 
638
    def has(self, relpath):
 
639
        """Indicate whether a remote file of the given name exists or not.
 
640
 
 
641
        :see: Transport.has()
 
642
        """
 
643
        resp = self._client._call('has', self._remote_path(relpath))
 
644
        if resp == ('yes', ):
 
645
            return True
 
646
        elif resp == ('no', ):
 
647
            return False
 
648
        else:
 
649
            self._translate_error(resp)
 
650
 
 
651
    def get(self, relpath):
 
652
        """Return file-like object reading the contents of a remote file.
 
653
        
 
654
        :see: Transport.get_bytes()/get_file()
 
655
        """
 
656
        remote = self._remote_path(relpath)
 
657
        resp = self._client._call('get', remote)
 
658
        if resp != ('ok', ):
 
659
            self._translate_error(resp, relpath)
 
660
        return StringIO(self._client._recv_bulk())
 
661
 
 
662
    def _serialise_optional_mode(self, mode):
 
663
        if mode is None:
 
664
            return ''
 
665
        else:
 
666
            return '%d' % mode
 
667
 
 
668
    def mkdir(self, relpath, mode=None):
 
669
        resp = self._client._call('mkdir', 
 
670
                                  self._remote_path(relpath), 
 
671
                                  self._serialise_optional_mode(mode))
 
672
        self._translate_error(resp)
 
673
 
 
674
    def put_file(self, relpath, upload_file, mode=None):
 
675
        # its not ideal to seek back, but currently put_non_atomic_file depends
 
676
        # on transports not reading before failing - which is a faulty
 
677
        # assumption I think - RBC 20060915
 
678
        pos = upload_file.tell()
 
679
        try:
 
680
            return self.put_bytes(relpath, upload_file.read(), mode)
 
681
        except:
 
682
            upload_file.seek(pos)
 
683
            raise
 
684
 
 
685
    def put_bytes(self, relpath, upload_contents, mode=None):
 
686
        # FIXME: upload_file is probably not safe for non-ascii characters -
 
687
        # should probably just pass all parameters as length-delimited
 
688
        # strings?
 
689
        resp = self._client._call_with_upload(
 
690
            'put',
 
691
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
 
692
            upload_contents)
 
693
        self._translate_error(resp)
 
694
 
 
695
    def append_file(self, relpath, from_file, mode=None):
 
696
        return self.append_bytes(relpath, from_file.read(), mode)
 
697
        
 
698
    def append_bytes(self, relpath, bytes, mode=None):
 
699
        resp = self._client._call_with_upload(
 
700
            'append',
 
701
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
 
702
            bytes)
 
703
        if resp[0] == 'appended':
 
704
            return int(resp[1])
 
705
        self._translate_error(resp)
 
706
 
 
707
    def delete(self, relpath):
 
708
        resp = self._client._call('delete', self._remote_path(relpath))
 
709
        self._translate_error(resp)
 
710
 
 
711
    @staticmethod
 
712
    def _serialise_offsets(offsets):
 
713
        txt = []
 
714
        for start, length in offsets:
 
715
            txt.append('%d,%d' % (start, length))
 
716
        return '\n'.join(txt)
 
717
 
 
718
    def readv(self, relpath, offsets):
 
719
        if not offsets:
 
720
            return
 
721
 
 
722
        offsets = list(offsets)
 
723
        resp = self._client._call_with_upload(
 
724
            'readv',
 
725
            (self._remote_path(relpath),),
 
726
            self._serialise_offsets(offsets))
 
727
 
 
728
        if resp[0] != 'readv':
 
729
            self._translate_error(resp)
 
730
        else:
 
731
            data = self._client._recv_bulk()
 
732
            cur_pos = 0
 
733
            for start, length in offsets:
 
734
                next_pos = cur_pos + length
 
735
                if len(data) < next_pos:
 
736
                    raise errors.ShortReadvError(relpath, start, length,
 
737
                                                 actual=len(data)-cur_pos)
 
738
                cur_data = data[cur_pos:next_pos]
 
739
                cur_pos = next_pos
 
740
                yield start, cur_data
 
741
 
 
742
    def rename(self, rel_from, rel_to):
 
743
        self._call('rename', 
 
744
                   self._remote_path(rel_from),
 
745
                   self._remote_path(rel_to))
 
746
 
 
747
    def move(self, rel_from, rel_to):
 
748
        self._call('move', 
 
749
                   self._remote_path(rel_from),
 
750
                   self._remote_path(rel_to))
 
751
 
 
752
    def rmdir(self, relpath):
 
753
        resp = self._call('rmdir', self._remote_path(relpath))
 
754
 
 
755
    def _call(self, method, *args):
 
756
        resp = self._client._call(method, *args)
 
757
        self._translate_error(resp)
 
758
 
 
759
    def _translate_error(self, resp, orig_path=None):
 
760
        """Raise an exception from a response"""
 
761
        what = resp[0]
 
762
        if what == 'ok':
 
763
            return
 
764
        elif what == 'NoSuchFile':
 
765
            if orig_path is not None:
 
766
                error_path = orig_path
 
767
            else:
 
768
                error_path = resp[1]
 
769
            raise errors.NoSuchFile(error_path)
 
770
        elif what == 'error':
 
771
            raise errors.SmartProtocolError(unicode(resp[1]))
 
772
        elif what == 'FileExists':
 
773
            raise errors.FileExists(resp[1])
 
774
        elif what == 'DirectoryNotEmpty':
 
775
            raise errors.DirectoryNotEmpty(resp[1])
 
776
        elif what == 'ShortReadvError':
 
777
            raise errors.ShortReadvError(resp[1], int(resp[2]),
 
778
                                         int(resp[3]), int(resp[4]))
 
779
        elif what in ('UnicodeEncodeError', 'UnicodeDecodeError'):
 
780
            encoding = str(resp[1]) # encoding must always be a string
 
781
            val = resp[2]
 
782
            start = int(resp[3])
 
783
            end = int(resp[4])
 
784
            reason = str(resp[5]) # reason must always be a string
 
785
            if val.startswith('u:'):
 
786
                val = val[2:]
 
787
            elif val.startswith('s:'):
 
788
                val = val[2:].decode('base64')
 
789
            if what == 'UnicodeDecodeError':
 
790
                raise UnicodeDecodeError(encoding, val, start, end, reason)
 
791
            elif what == 'UnicodeEncodeError':
 
792
                raise UnicodeEncodeError(encoding, val, start, end, reason)
 
793
        else:
 
794
            raise errors.SmartProtocolError('unexpected smart server error: %r' % (resp,))
 
795
 
 
796
    def _send_tuple(self, args):
 
797
        self._client._send_tuple(args)
 
798
 
 
799
    def _recv_tuple(self):
 
800
        return self._client._recv_tuple()
 
801
 
 
802
    def disconnect(self):
 
803
        self._client.disconnect()
 
804
 
 
805
    def delete_tree(self, relpath):
 
806
        raise errors.TransportNotPossible('readonly transport')
 
807
 
 
808
    def stat(self, relpath):
 
809
        resp = self._client._call('stat', self._remote_path(relpath))
 
810
        if resp[0] == 'stat':
 
811
            return SmartStat(int(resp[1]), int(resp[2], 8))
 
812
        else:
 
813
            self._translate_error(resp)
 
814
 
 
815
    ## def lock_read(self, relpath):
 
816
    ##     """Lock the given file for shared (read) access.
 
817
    ##     :return: A lock object, which should be passed to Transport.unlock()
 
818
    ##     """
 
819
    ##     # The old RemoteBranch ignore lock for reading, so we will
 
820
    ##     # continue that tradition and return a bogus lock object.
 
821
    ##     class BogusLock(object):
 
822
    ##         def __init__(self, path):
 
823
    ##             self.path = path
 
824
    ##         def unlock(self):
 
825
    ##             pass
 
826
    ##     return BogusLock(relpath)
 
827
 
 
828
    def listable(self):
 
829
        return True
 
830
 
 
831
    def list_dir(self, relpath):
 
832
        resp = self._client._call('list_dir',
 
833
                                  self._remote_path(relpath))
 
834
        if resp[0] == 'names':
 
835
            return [name.encode('ascii') for name in resp[1:]]
 
836
        else:
 
837
            self._translate_error(resp)
 
838
 
 
839
    def iter_files_recursive(self):
 
840
        resp = self._client._call('iter_files_recursive',
 
841
                                  self._remote_path(''))
 
842
        if resp[0] == 'names':
 
843
            return resp[1:]
 
844
        else:
 
845
            self._translate_error(resp)
 
846
 
 
847
 
 
848
class SmartStreamClient(SmartProtocolBase):
 
849
    """Connection to smart server over two streams"""
 
850
 
 
851
    def __init__(self, connect_func):
 
852
        self._connect_func = connect_func
 
853
        self._connected = False
 
854
 
 
855
    def __del__(self):
 
856
        self.disconnect()
 
857
 
 
858
    def _ensure_connection(self):
 
859
        if not self._connected:
 
860
            self._in, self._out = self._connect_func()
 
861
            self._connected = True
 
862
 
 
863
    def _send_tuple(self, args):
 
864
        self._ensure_connection()
 
865
        _send_tuple(self._out, args)
 
866
 
 
867
    def _send_bulk_data(self, body):
 
868
        self._ensure_connection()
 
869
        SmartProtocolBase._send_bulk_data(self, body)
 
870
        
 
871
    def _recv_bulk(self):
 
872
        self._ensure_connection()
 
873
        return SmartProtocolBase._recv_bulk(self)
 
874
 
 
875
    def _recv_tuple(self):
 
876
        self._ensure_connection()
 
877
        return SmartProtocolBase._recv_tuple(self)
 
878
 
 
879
    def _recv_trailer(self):
 
880
        self._ensure_connection()
 
881
        return SmartProtocolBase._recv_trailer(self)
 
882
 
 
883
    def disconnect(self):
 
884
        """Close connection to the server"""
 
885
        if self._connected:
 
886
            self._out.close()
 
887
            self._in.close()
 
888
 
 
889
    def _call(self, *args):
 
890
        self._send_tuple(args)
 
891
        return self._recv_tuple()
 
892
 
 
893
    def _call_with_upload(self, method, args, body):
 
894
        """Call an rpc, supplying bulk upload data.
 
895
 
 
896
        :param method: method name to call
 
897
        :param args: parameter args tuple
 
898
        :param body: upload body as a byte string
 
899
        """
 
900
        self._send_tuple((method,) + args)
 
901
        self._send_bulk_data(body)
 
902
        return self._recv_tuple()
 
903
 
 
904
    def query_version(self):
 
905
        """Return protocol version number of the server."""
 
906
        # XXX: should make sure it's empty
 
907
        self._send_tuple(('hello',))
 
908
        resp = self._recv_tuple()
 
909
        if resp == ('ok', '1'):
 
910
            return 1
 
911
        else:
 
912
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
913
 
 
914
 
 
915
class SmartTCPTransport(SmartTransport):
 
916
    """Connection to smart server over plain tcp"""
 
917
 
 
918
    def __init__(self, url, clone_from=None):
 
919
        super(SmartTCPTransport, self).__init__(url, clone_from)
 
920
        try:
 
921
            self._port = int(self._port)
 
922
        except (ValueError, TypeError), e:
 
923
            raise errors.InvalidURL(path=url, extra="invalid port %s" % self._port)
 
924
        self._socket = None
 
925
 
 
926
    def _connect_to_server(self):
 
927
        self._socket = socket.socket()
 
928
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
929
        result = self._socket.connect_ex((self._host, int(self._port)))
 
930
        if result:
 
931
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
 
932
                    (self._host, self._port, os.strerror(result)))
 
933
        # TODO: May be more efficient to just treat them as sockets
 
934
        # throughout?  But what about pipes to ssh?...
 
935
        to_server = self._socket.makefile('w')
 
936
        from_server = self._socket.makefile('r')
 
937
        return from_server, to_server
 
938
 
 
939
    def disconnect(self):
 
940
        super(SmartTCPTransport, self).disconnect()
 
941
        # XXX: Is closing the socket as well as closing the files really
 
942
        # necessary?
 
943
        if self._socket is not None:
 
944
            self._socket.close()
 
945
 
 
946
try:
 
947
    from bzrlib.transport import sftp
 
948
except errors.ParamikoNotPresent:
 
949
    # no paramiko, no SSHTransport.
 
950
    pass
 
951
else:
 
952
    class SmartSSHTransport(SmartTransport):
 
953
        """Connection to smart server over SSH."""
 
954
 
 
955
        def __init__(self, url, clone_from=None):
 
956
            # TODO: all this probably belongs in the parent class.
 
957
            super(SmartSSHTransport, self).__init__(url, clone_from)
 
958
            try:
 
959
                if self._port is not None:
 
960
                    self._port = int(self._port)
 
961
            except (ValueError, TypeError), e:
 
962
                raise errors.InvalidURL(path=url, extra="invalid port %s" % self._port)
 
963
 
 
964
        def _connect_to_server(self):
 
965
            # XXX: don't hardcode vendor
 
966
            # XXX: cannot pass password to SSHSubprocess yet
 
967
            if self._password is not None:
 
968
                raise errors.InvalidURL("SSH smart transport doesn't handle passwords")
 
969
            self._ssh_connection = sftp.SSHSubprocess(self._host, 'openssh',
 
970
                    port=self._port, user=self._username,
 
971
                    command=['bzr', 'serve', '--inet'])
 
972
            return self._ssh_connection.get_filelike_channels()
 
973
 
 
974
        def disconnect(self):
 
975
            super(SmartSSHTransport, self).disconnect()
 
976
            self._ssh_connection.close()
 
977
 
 
978
 
 
979
def get_test_permutations():
 
980
    """Return (transport, server) permutations for testing"""
 
981
    return [(SmartTCPTransport, SmartTCPServer_for_testing)]