/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/transport/smart.py

  • Committer: Aaron Bentley
  • Date: 2007-04-10 21:05:17 UTC
  • mto: (1551.19.24 Aaron's mergeable stuff)
  • mto: This revision was merged to the branch mainline in revision 2405.
  • Revision ID: abentley@panoramicfeedback.com-20070410210517-0m7mhl5d2fhs66u5
Move cat-revision tests out of test_revision_info

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""Smart-server protocol, client and server.
 
18
 
 
19
Requests are sent as a command and list of arguments, followed by optional
 
20
bulk body data.  Responses are similarly a response and list of arguments,
 
21
followed by bulk body data. ::
 
22
 
 
23
  SEP := '\001'
 
24
    Fields are separated by Ctrl-A.
 
25
  BULK_DATA := CHUNK TRAILER
 
26
    Chunks can be repeated as many times as necessary.
 
27
  CHUNK := CHUNK_LEN CHUNK_BODY
 
28
  CHUNK_LEN := DIGIT+ NEWLINE
 
29
    Gives the number of bytes in the following chunk.
 
30
  CHUNK_BODY := BYTE[chunk_len]
 
31
  TRAILER := SUCCESS_TRAILER | ERROR_TRAILER
 
32
  SUCCESS_TRAILER := 'done' NEWLINE
 
33
  ERROR_TRAILER := 
 
34
 
 
35
Paths are passed across the network.  The client needs to see a namespace that
 
36
includes any repository that might need to be referenced, and the client needs
 
37
to know about a root directory beyond which it cannot ascend.
 
38
 
 
39
Servers run over ssh will typically want to be able to access any path the user 
 
40
can access.  Public servers on the other hand (which might be over http, ssh
 
41
or tcp) will typically want to restrict access to only a particular directory 
 
42
and its children, so will want to do a software virtual root at that level.
 
43
In other words they'll want to rewrite incoming paths to be under that level
 
44
(and prevent escaping using ../ tricks.)
 
45
 
 
46
URLs that include ~ should probably be passed across to the server verbatim
 
47
and the server can expand them.  This will proably not be meaningful when 
 
48
limited to a directory?
 
49
 
 
50
At the bottom level socket, pipes, HTTP server.  For sockets, we have the idea
 
51
that you have multiple requests and get a read error because the other side did
 
52
shutdown.  For pipes we have read pipe which will have a zero read which marks
 
53
end-of-file.  For HTTP server environment there is not end-of-stream because
 
54
each request coming into the server is independent.
 
55
 
 
56
So we need a wrapper around pipes and sockets to seperate out requests from
 
57
substrate and this will give us a single model which is consist for HTTP,
 
58
sockets and pipes.
 
59
 
 
60
Server-side
 
61
-----------
 
62
 
 
63
 MEDIUM  (factory for protocol, reads bytes & pushes to protocol,
 
64
          uses protocol to detect end-of-request, sends written
 
65
          bytes to client) e.g. socket, pipe, HTTP request handler.
 
66
  ^
 
67
  | bytes.
 
68
  v
 
69
 
 
70
PROTOCOL  (serialization, deserialization)  accepts bytes for one
 
71
          request, decodes according to internal state, pushes
 
72
          structured data to handler.  accepts structured data from
 
73
          handler and encodes and writes to the medium.  factory for
 
74
          handler.
 
75
  ^
 
76
  | structured data
 
77
  v
 
78
 
 
79
HANDLER   (domain logic) accepts structured data, operates state
 
80
          machine until the request can be satisfied,
 
81
          sends structured data to the protocol.
 
82
 
 
83
 
 
84
Client-side
 
85
-----------
 
86
 
 
87
 CLIENT             domain logic, accepts domain requests, generated structured
 
88
                    data, reads structured data from responses and turns into
 
89
                    domain data.  Sends structured data to the protocol.
 
90
                    Operates state machines until the request can be delivered
 
91
                    (e.g. reading from a bundle generated in bzrlib to deliver a
 
92
                    complete request).
 
93
 
 
94
                    Possibly this should just be RemoteBzrDir, RemoteTransport,
 
95
                    ...
 
96
  ^
 
97
  | structured data
 
98
  v
 
99
 
 
100
PROTOCOL  (serialization, deserialization)  accepts structured data for one
 
101
          request, encodes and writes to the medium.  Reads bytes from the
 
102
          medium, decodes and allows the client to read structured data.
 
103
  ^
 
104
  | bytes.
 
105
  v
 
106
 
 
107
 MEDIUM  (accepts bytes from the protocol & delivers to the remote server.
 
108
          Allows the potocol to read bytes e.g. socket, pipe, HTTP request.
 
109
"""
 
110
 
 
111
 
 
112
# TODO: _translate_error should be on the client, not the transport because
 
113
#     error coding is wire protocol specific.
 
114
 
 
115
# TODO: A plain integer from query_version is too simple; should give some
 
116
# capabilities too?
 
117
 
 
118
# TODO: Server should probably catch exceptions within itself and send them
 
119
# back across the network.  (But shouldn't catch KeyboardInterrupt etc)
 
120
# Also needs to somehow report protocol errors like bad requests.  Need to
 
121
# consider how we'll handle error reporting, e.g. if we get halfway through a
 
122
# bulk transfer and then something goes wrong.
 
123
 
 
124
# TODO: Standard marker at start of request/response lines?
 
125
 
 
126
# TODO: Make each request and response self-validatable, e.g. with checksums.
 
127
#
 
128
# TODO: get/put objects could be changed to gradually read back the data as it
 
129
# comes across the network
 
130
#
 
131
# TODO: What should the server do if it hits an error and has to terminate?
 
132
#
 
133
# TODO: is it useful to allow multiple chunks in the bulk data?
 
134
#
 
135
# TODO: If we get an exception during transmission of bulk data we can't just
 
136
# emit the exception because it won't be seen.
 
137
#   John proposes:  I think it would be worthwhile to have a header on each
 
138
#   chunk, that indicates it is another chunk. Then you can send an 'error'
 
139
#   chunk as long as you finish the previous chunk.
 
140
#
 
141
# TODO: Clone method on Transport; should work up towards parent directory;
 
142
# unclear how this should be stored or communicated to the server... maybe
 
143
# just pass it on all relevant requests?
 
144
#
 
145
# TODO: Better name than clone() for changing between directories.  How about
 
146
# open_dir or change_dir or chdir?
 
147
#
 
148
# TODO: Is it really good to have the notion of current directory within the
 
149
# connection?  Perhaps all Transports should factor out a common connection
 
150
# from the thing that has the directory context?
 
151
#
 
152
# TODO: Pull more things common to sftp and ssh to a higher level.
 
153
#
 
154
# TODO: The server that manages a connection should be quite small and retain
 
155
# minimum state because each of the requests are supposed to be stateless.
 
156
# Then we can write another implementation that maps to http.
 
157
#
 
158
# TODO: What to do when a client connection is garbage collected?  Maybe just
 
159
# abruptly drop the connection?
 
160
#
 
161
# TODO: Server in some cases will need to restrict access to files outside of
 
162
# a particular root directory.  LocalTransport doesn't do anything to stop you
 
163
# ascending above the base directory, so we need to prevent paths
 
164
# containing '..' in either the server or transport layers.  (Also need to
 
165
# consider what happens if someone creates a symlink pointing outside the 
 
166
# directory tree...)
 
167
#
 
168
# TODO: Server should rebase absolute paths coming across the network to put
 
169
# them under the virtual root, if one is in use.  LocalTransport currently
 
170
# doesn't do that; if you give it an absolute path it just uses it.
 
171
 
172
# XXX: Arguments can't contain newlines or ascii; possibly we should e.g.
 
173
# urlescape them instead.  Indeed possibly this should just literally be
 
174
# http-over-ssh.
 
175
#
 
176
# FIXME: This transport, with several others, has imperfect handling of paths
 
177
# within urls.  It'd probably be better for ".." from a root to raise an error
 
178
# rather than return the same directory as we do at present.
 
179
#
 
180
# TODO: Rather than working at the Transport layer we want a Branch,
 
181
# Repository or BzrDir objects that talk to a server.
 
182
#
 
183
# TODO: Probably want some way for server commands to gradually produce body
 
184
# data rather than passing it as a string; they could perhaps pass an
 
185
# iterator-like callback that will gradually yield data; it probably needs a
 
186
# close() method that will always be closed to do any necessary cleanup.
 
187
#
 
188
# TODO: Split the actual smart server from the ssh encoding of it.
 
189
#
 
190
# TODO: Perhaps support file-level readwrite operations over the transport
 
191
# too.
 
192
#
 
193
# TODO: SmartBzrDir class, proxying all Branch etc methods across to another
 
194
# branch doing file-level operations.
 
195
#
 
196
 
 
197
from cStringIO import StringIO
 
198
import errno
 
199
import os
 
200
import socket
 
201
import sys
 
202
import tempfile
 
203
import threading
 
204
import urllib
 
205
import urlparse
 
206
 
 
207
from bzrlib import (
 
208
    bzrdir,
 
209
    errors,
 
210
    revision,
 
211
    transport,
 
212
    trace,
 
213
    urlutils,
 
214
    )
 
215
from bzrlib.bundle.serializer import write_bundle
 
216
from bzrlib.hooks import Hooks
 
217
try:
 
218
    from bzrlib.transport import ssh
 
219
except errors.ParamikoNotPresent:
 
220
    # no paramiko.  SmartSSHClientMedium will break.
 
221
    pass
 
222
 
 
223
# must do this otherwise urllib can't parse the urls properly :(
 
224
for scheme in ['ssh', 'bzr', 'bzr+loopback', 'bzr+ssh', 'bzr+http']:
 
225
    transport.register_urlparse_netloc_protocol(scheme)
 
226
del scheme
 
227
 
 
228
 
 
229
# Port 4155 is the default port for bzr://, registered with IANA.
 
230
BZR_DEFAULT_PORT = 4155
 
231
 
 
232
 
 
233
def _recv_tuple(from_file):
 
234
    req_line = from_file.readline()
 
235
    return _decode_tuple(req_line)
 
236
 
 
237
 
 
238
def _decode_tuple(req_line):
 
239
    if req_line == None or req_line == '':
 
240
        return None
 
241
    if req_line[-1] != '\n':
 
242
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
243
    return tuple(req_line[:-1].split('\x01'))
 
244
 
 
245
 
 
246
def _encode_tuple(args):
 
247
    """Encode the tuple args to a bytestream."""
 
248
    return '\x01'.join(args) + '\n'
 
249
 
 
250
 
 
251
class SmartProtocolBase(object):
 
252
    """Methods common to client and server"""
 
253
 
 
254
    # TODO: this only actually accomodates a single block; possibly should
 
255
    # support multiple chunks?
 
256
    def _encode_bulk_data(self, body):
 
257
        """Encode body as a bulk data chunk."""
 
258
        return ''.join(('%d\n' % len(body), body, 'done\n'))
 
259
 
 
260
    def _serialise_offsets(self, offsets):
 
261
        """Serialise a readv offset list."""
 
262
        txt = []
 
263
        for start, length in offsets:
 
264
            txt.append('%d,%d' % (start, length))
 
265
        return '\n'.join(txt)
 
266
        
 
267
 
 
268
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
269
    """Server-side encoding and decoding logic for smart version 1."""
 
270
    
 
271
    def __init__(self, backing_transport, write_func):
 
272
        self._backing_transport = backing_transport
 
273
        self.excess_buffer = ''
 
274
        self._finished = False
 
275
        self.in_buffer = ''
 
276
        self.has_dispatched = False
 
277
        self.request = None
 
278
        self._body_decoder = None
 
279
        self._write_func = write_func
 
280
 
 
281
    def accept_bytes(self, bytes):
 
282
        """Take bytes, and advance the internal state machine appropriately.
 
283
        
 
284
        :param bytes: must be a byte string
 
285
        """
 
286
        assert isinstance(bytes, str)
 
287
        self.in_buffer += bytes
 
288
        if not self.has_dispatched:
 
289
            if '\n' not in self.in_buffer:
 
290
                # no command line yet
 
291
                return
 
292
            self.has_dispatched = True
 
293
            try:
 
294
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
 
295
                first_line += '\n'
 
296
                req_args = _decode_tuple(first_line)
 
297
                self.request = SmartServerRequestHandler(
 
298
                    self._backing_transport)
 
299
                self.request.dispatch_command(req_args[0], req_args[1:])
 
300
                if self.request.finished_reading:
 
301
                    # trivial request
 
302
                    self.excess_buffer = self.in_buffer
 
303
                    self.in_buffer = ''
 
304
                    self._send_response(self.request.response.args,
 
305
                        self.request.response.body)
 
306
            except KeyboardInterrupt:
 
307
                raise
 
308
            except Exception, exception:
 
309
                # everything else: pass to client, flush, and quit
 
310
                self._send_response(('error', str(exception)))
 
311
                return
 
312
 
 
313
        if self.has_dispatched:
 
314
            if self._finished:
 
315
                # nothing to do.XXX: this routine should be a single state 
 
316
                # machine too.
 
317
                self.excess_buffer += self.in_buffer
 
318
                self.in_buffer = ''
 
319
                return
 
320
            if self._body_decoder is None:
 
321
                self._body_decoder = LengthPrefixedBodyDecoder()
 
322
            self._body_decoder.accept_bytes(self.in_buffer)
 
323
            self.in_buffer = self._body_decoder.unused_data
 
324
            body_data = self._body_decoder.read_pending_data()
 
325
            self.request.accept_body(body_data)
 
326
            if self._body_decoder.finished_reading:
 
327
                self.request.end_of_body()
 
328
                assert self.request.finished_reading, \
 
329
                    "no more body, request not finished"
 
330
            if self.request.response is not None:
 
331
                self._send_response(self.request.response.args,
 
332
                    self.request.response.body)
 
333
                self.excess_buffer = self.in_buffer
 
334
                self.in_buffer = ''
 
335
            else:
 
336
                assert not self.request.finished_reading, \
 
337
                    "no response and we have finished reading."
 
338
 
 
339
    def _send_response(self, args, body=None):
 
340
        """Send a smart server response down the output stream."""
 
341
        assert not self._finished, 'response already sent'
 
342
        self._finished = True
 
343
        self._write_func(_encode_tuple(args))
 
344
        if body is not None:
 
345
            assert isinstance(body, str), 'body must be a str'
 
346
            bytes = self._encode_bulk_data(body)
 
347
            self._write_func(bytes)
 
348
 
 
349
    def next_read_size(self):
 
350
        if self._finished:
 
351
            return 0
 
352
        if self._body_decoder is None:
 
353
            return 1
 
354
        else:
 
355
            return self._body_decoder.next_read_size()
 
356
 
 
357
 
 
358
class LengthPrefixedBodyDecoder(object):
 
359
    """Decodes the length-prefixed bulk data."""
 
360
    
 
361
    def __init__(self):
 
362
        self.bytes_left = None
 
363
        self.finished_reading = False
 
364
        self.unused_data = ''
 
365
        self.state_accept = self._state_accept_expecting_length
 
366
        self.state_read = self._state_read_no_data
 
367
        self._in_buffer = ''
 
368
        self._trailer_buffer = ''
 
369
    
 
370
    def accept_bytes(self, bytes):
 
371
        """Decode as much of bytes as possible.
 
372
 
 
373
        If 'bytes' contains too much data it will be appended to
 
374
        self.unused_data.
 
375
 
 
376
        finished_reading will be set when no more data is required.  Further
 
377
        data will be appended to self.unused_data.
 
378
        """
 
379
        # accept_bytes is allowed to change the state
 
380
        current_state = self.state_accept
 
381
        self.state_accept(bytes)
 
382
        while current_state != self.state_accept:
 
383
            current_state = self.state_accept
 
384
            self.state_accept('')
 
385
 
 
386
    def next_read_size(self):
 
387
        if self.bytes_left is not None:
 
388
            # Ideally we want to read all the remainder of the body and the
 
389
            # trailer in one go.
 
390
            return self.bytes_left + 5
 
391
        elif self.state_accept == self._state_accept_reading_trailer:
 
392
            # Just the trailer left
 
393
            return 5 - len(self._trailer_buffer)
 
394
        elif self.state_accept == self._state_accept_expecting_length:
 
395
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
396
            # 'done\n').
 
397
            return 6
 
398
        else:
 
399
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
400
            return 1
 
401
        
 
402
    def read_pending_data(self):
 
403
        """Return any pending data that has been decoded."""
 
404
        return self.state_read()
 
405
 
 
406
    def _state_accept_expecting_length(self, bytes):
 
407
        self._in_buffer += bytes
 
408
        pos = self._in_buffer.find('\n')
 
409
        if pos == -1:
 
410
            return
 
411
        self.bytes_left = int(self._in_buffer[:pos])
 
412
        self._in_buffer = self._in_buffer[pos+1:]
 
413
        self.bytes_left -= len(self._in_buffer)
 
414
        self.state_accept = self._state_accept_reading_body
 
415
        self.state_read = self._state_read_in_buffer
 
416
 
 
417
    def _state_accept_reading_body(self, bytes):
 
418
        self._in_buffer += bytes
 
419
        self.bytes_left -= len(bytes)
 
420
        if self.bytes_left <= 0:
 
421
            # Finished with body
 
422
            if self.bytes_left != 0:
 
423
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
 
424
                self._in_buffer = self._in_buffer[:self.bytes_left]
 
425
            self.bytes_left = None
 
426
            self.state_accept = self._state_accept_reading_trailer
 
427
        
 
428
    def _state_accept_reading_trailer(self, bytes):
 
429
        self._trailer_buffer += bytes
 
430
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
431
        # a ProtocolViolation exception?
 
432
        if self._trailer_buffer.startswith('done\n'):
 
433
            self.unused_data = self._trailer_buffer[len('done\n'):]
 
434
            self.state_accept = self._state_accept_reading_unused
 
435
            self.finished_reading = True
 
436
    
 
437
    def _state_accept_reading_unused(self, bytes):
 
438
        self.unused_data += bytes
 
439
 
 
440
    def _state_read_no_data(self):
 
441
        return ''
 
442
 
 
443
    def _state_read_in_buffer(self):
 
444
        result = self._in_buffer
 
445
        self._in_buffer = ''
 
446
        return result
 
447
 
 
448
 
 
449
class SmartServerStreamMedium(object):
 
450
    """Handles smart commands coming over a stream.
 
451
 
 
452
    The stream may be a pipe connected to sshd, or a tcp socket, or an
 
453
    in-process fifo for testing.
 
454
 
 
455
    One instance is created for each connected client; it can serve multiple
 
456
    requests in the lifetime of the connection.
 
457
 
 
458
    The server passes requests through to an underlying backing transport, 
 
459
    which will typically be a LocalTransport looking at the server's filesystem.
 
460
    """
 
461
 
 
462
    def __init__(self, backing_transport):
 
463
        """Construct new server.
 
464
 
 
465
        :param backing_transport: Transport for the directory served.
 
466
        """
 
467
        # backing_transport could be passed to serve instead of __init__
 
468
        self.backing_transport = backing_transport
 
469
        self.finished = False
 
470
 
 
471
    def serve(self):
 
472
        """Serve requests until the client disconnects."""
 
473
        # Keep a reference to stderr because the sys module's globals get set to
 
474
        # None during interpreter shutdown.
 
475
        from sys import stderr
 
476
        try:
 
477
            while not self.finished:
 
478
                protocol = SmartServerRequestProtocolOne(self.backing_transport,
 
479
                                                         self._write_out)
 
480
                self._serve_one_request(protocol)
 
481
        except Exception, e:
 
482
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
483
            raise
 
484
 
 
485
    def _serve_one_request(self, protocol):
 
486
        """Read one request from input, process, send back a response.
 
487
        
 
488
        :param protocol: a SmartServerRequestProtocol.
 
489
        """
 
490
        try:
 
491
            self._serve_one_request_unguarded(protocol)
 
492
        except KeyboardInterrupt:
 
493
            raise
 
494
        except Exception, e:
 
495
            self.terminate_due_to_error()
 
496
 
 
497
    def terminate_due_to_error(self):
 
498
        """Called when an unhandled exception from the protocol occurs."""
 
499
        raise NotImplementedError(self.terminate_due_to_error)
 
500
 
 
501
 
 
502
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
 
503
 
 
504
    def __init__(self, sock, backing_transport):
 
505
        """Constructor.
 
506
 
 
507
        :param sock: the socket the server will read from.  It will be put
 
508
            into blocking mode.
 
509
        """
 
510
        SmartServerStreamMedium.__init__(self, backing_transport)
 
511
        self.push_back = ''
 
512
        sock.setblocking(True)
 
513
        self.socket = sock
 
514
 
 
515
    def _serve_one_request_unguarded(self, protocol):
 
516
        while protocol.next_read_size():
 
517
            if self.push_back:
 
518
                protocol.accept_bytes(self.push_back)
 
519
                self.push_back = ''
 
520
            else:
 
521
                bytes = self.socket.recv(4096)
 
522
                if bytes == '':
 
523
                    self.finished = True
 
524
                    return
 
525
                protocol.accept_bytes(bytes)
 
526
        
 
527
        self.push_back = protocol.excess_buffer
 
528
    
 
529
    def terminate_due_to_error(self):
 
530
        """Called when an unhandled exception from the protocol occurs."""
 
531
        # TODO: This should log to a server log file, but no such thing
 
532
        # exists yet.  Andrew Bennetts 2006-09-29.
 
533
        self.socket.close()
 
534
        self.finished = True
 
535
 
 
536
    def _write_out(self, bytes):
 
537
        self.socket.sendall(bytes)
 
538
 
 
539
 
 
540
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
 
541
 
 
542
    def __init__(self, in_file, out_file, backing_transport):
 
543
        """Construct new server.
 
544
 
 
545
        :param in_file: Python file from which requests can be read.
 
546
        :param out_file: Python file to write responses.
 
547
        :param backing_transport: Transport for the directory served.
 
548
        """
 
549
        SmartServerStreamMedium.__init__(self, backing_transport)
 
550
        if sys.platform == 'win32':
 
551
            # force binary mode for files
 
552
            import msvcrt
 
553
            for f in (in_file, out_file):
 
554
                fileno = getattr(f, 'fileno', None)
 
555
                if fileno:
 
556
                    msvcrt.setmode(fileno(), os.O_BINARY)
 
557
        self._in = in_file
 
558
        self._out = out_file
 
559
 
 
560
    def _serve_one_request_unguarded(self, protocol):
 
561
        while True:
 
562
            bytes_to_read = protocol.next_read_size()
 
563
            if bytes_to_read == 0:
 
564
                # Finished serving this request.
 
565
                self._out.flush()
 
566
                return
 
567
            bytes = self._in.read(bytes_to_read)
 
568
            if bytes == '':
 
569
                # Connection has been closed.
 
570
                self.finished = True
 
571
                self._out.flush()
 
572
                return
 
573
            protocol.accept_bytes(bytes)
 
574
 
 
575
    def terminate_due_to_error(self):
 
576
        # TODO: This should log to a server log file, but no such thing
 
577
        # exists yet.  Andrew Bennetts 2006-09-29.
 
578
        self._out.close()
 
579
        self.finished = True
 
580
 
 
581
    def _write_out(self, bytes):
 
582
        self._out.write(bytes)
 
583
 
 
584
 
 
585
class SmartServerResponse(object):
 
586
    """Response generated by SmartServerRequestHandler."""
 
587
 
 
588
    def __init__(self, args, body=None):
 
589
        self.args = args
 
590
        self.body = body
 
591
 
 
592
# XXX: TODO: Create a SmartServerRequestHandler which will take the responsibility
 
593
# for delivering the data for a request. This could be done with as the
 
594
# StreamServer, though that would create conflation between request and response
 
595
# which may be undesirable.
 
596
 
 
597
 
 
598
class SmartServerRequestHandler(object):
 
599
    """Protocol logic for smart server.
 
600
    
 
601
    This doesn't handle serialization at all, it just processes requests and
 
602
    creates responses.
 
603
    """
 
604
 
 
605
    # IMPORTANT FOR IMPLEMENTORS: It is important that SmartServerRequestHandler
 
606
    # not contain encoding or decoding logic to allow the wire protocol to vary
 
607
    # from the object protocol: we will want to tweak the wire protocol separate
 
608
    # from the object model, and ideally we will be able to do that without
 
609
    # having a SmartServerRequestHandler subclass for each wire protocol, rather
 
610
    # just a Protocol subclass.
 
611
 
 
612
    # TODO: Better way of representing the body for commands that take it,
 
613
    # and allow it to be streamed into the server.
 
614
    
 
615
    def __init__(self, backing_transport):
 
616
        self._backing_transport = backing_transport
 
617
        self._converted_command = False
 
618
        self.finished_reading = False
 
619
        self._body_bytes = ''
 
620
        self.response = None
 
621
 
 
622
    def accept_body(self, bytes):
 
623
        """Accept body data.
 
624
 
 
625
        This should be overriden for each command that desired body data to
 
626
        handle the right format of that data. I.e. plain bytes, a bundle etc.
 
627
 
 
628
        The deserialisation into that format should be done in the Protocol
 
629
        object. Set self.desired_body_format to the format your method will
 
630
        handle.
 
631
        """
 
632
        # default fallback is to accumulate bytes.
 
633
        self._body_bytes += bytes
 
634
        
 
635
    def _end_of_body_handler(self):
 
636
        """An unimplemented end of body handler."""
 
637
        raise NotImplementedError(self._end_of_body_handler)
 
638
        
 
639
    def do_hello(self):
 
640
        """Answer a version request with my version."""
 
641
        return SmartServerResponse(('ok', '1'))
 
642
 
 
643
    def do_has(self, relpath):
 
644
        r = self._backing_transport.has(relpath) and 'yes' or 'no'
 
645
        return SmartServerResponse((r,))
 
646
 
 
647
    def do_get(self, relpath):
 
648
        backing_bytes = self._backing_transport.get_bytes(relpath)
 
649
        return SmartServerResponse(('ok',), backing_bytes)
 
650
 
 
651
    def _deserialise_optional_mode(self, mode):
 
652
        # XXX: FIXME this should be on the protocol object.
 
653
        if mode == '':
 
654
            return None
 
655
        else:
 
656
            return int(mode)
 
657
 
 
658
    def do_append(self, relpath, mode):
 
659
        self._converted_command = True
 
660
        self._relpath = relpath
 
661
        self._mode = self._deserialise_optional_mode(mode)
 
662
        self._end_of_body_handler = self._handle_do_append_end
 
663
    
 
664
    def _handle_do_append_end(self):
 
665
        old_length = self._backing_transport.append_bytes(
 
666
            self._relpath, self._body_bytes, self._mode)
 
667
        self.response = SmartServerResponse(('appended', '%d' % old_length))
 
668
 
 
669
    def do_delete(self, relpath):
 
670
        self._backing_transport.delete(relpath)
 
671
 
 
672
    def do_iter_files_recursive(self, relpath):
 
673
        transport = self._backing_transport.clone(relpath)
 
674
        filenames = transport.iter_files_recursive()
 
675
        return SmartServerResponse(('names',) + tuple(filenames))
 
676
 
 
677
    def do_list_dir(self, relpath):
 
678
        filenames = self._backing_transport.list_dir(relpath)
 
679
        return SmartServerResponse(('names',) + tuple(filenames))
 
680
 
 
681
    def do_mkdir(self, relpath, mode):
 
682
        self._backing_transport.mkdir(relpath,
 
683
                                      self._deserialise_optional_mode(mode))
 
684
 
 
685
    def do_move(self, rel_from, rel_to):
 
686
        self._backing_transport.move(rel_from, rel_to)
 
687
 
 
688
    def do_put(self, relpath, mode):
 
689
        self._converted_command = True
 
690
        self._relpath = relpath
 
691
        self._mode = self._deserialise_optional_mode(mode)
 
692
        self._end_of_body_handler = self._handle_do_put
 
693
 
 
694
    def _handle_do_put(self):
 
695
        self._backing_transport.put_bytes(self._relpath,
 
696
                self._body_bytes, self._mode)
 
697
        self.response = SmartServerResponse(('ok',))
 
698
 
 
699
    def _deserialise_offsets(self, text):
 
700
        # XXX: FIXME this should be on the protocol object.
 
701
        offsets = []
 
702
        for line in text.split('\n'):
 
703
            if not line:
 
704
                continue
 
705
            start, length = line.split(',')
 
706
            offsets.append((int(start), int(length)))
 
707
        return offsets
 
708
 
 
709
    def do_put_non_atomic(self, relpath, mode, create_parent, dir_mode):
 
710
        self._converted_command = True
 
711
        self._end_of_body_handler = self._handle_put_non_atomic
 
712
        self._relpath = relpath
 
713
        self._dir_mode = self._deserialise_optional_mode(dir_mode)
 
714
        self._mode = self._deserialise_optional_mode(mode)
 
715
        # a boolean would be nicer XXX
 
716
        self._create_parent = (create_parent == 'T')
 
717
 
 
718
    def _handle_put_non_atomic(self):
 
719
        self._backing_transport.put_bytes_non_atomic(self._relpath,
 
720
                self._body_bytes,
 
721
                mode=self._mode,
 
722
                create_parent_dir=self._create_parent,
 
723
                dir_mode=self._dir_mode)
 
724
        self.response = SmartServerResponse(('ok',))
 
725
 
 
726
    def do_readv(self, relpath):
 
727
        self._converted_command = True
 
728
        self._end_of_body_handler = self._handle_readv_offsets
 
729
        self._relpath = relpath
 
730
 
 
731
    def end_of_body(self):
 
732
        """No more body data will be received."""
 
733
        self._run_handler_code(self._end_of_body_handler, (), {})
 
734
        # cannot read after this.
 
735
        self.finished_reading = True
 
736
 
 
737
    def _handle_readv_offsets(self):
 
738
        """accept offsets for a readv request."""
 
739
        offsets = self._deserialise_offsets(self._body_bytes)
 
740
        backing_bytes = ''.join(bytes for offset, bytes in
 
741
            self._backing_transport.readv(self._relpath, offsets))
 
742
        self.response = SmartServerResponse(('readv',), backing_bytes)
 
743
        
 
744
    def do_rename(self, rel_from, rel_to):
 
745
        self._backing_transport.rename(rel_from, rel_to)
 
746
 
 
747
    def do_rmdir(self, relpath):
 
748
        self._backing_transport.rmdir(relpath)
 
749
 
 
750
    def do_stat(self, relpath):
 
751
        stat = self._backing_transport.stat(relpath)
 
752
        return SmartServerResponse(('stat', str(stat.st_size), oct(stat.st_mode)))
 
753
        
 
754
    def do_get_bundle(self, path, revision_id):
 
755
        # open transport relative to our base
 
756
        t = self._backing_transport.clone(path)
 
757
        control, extra_path = bzrdir.BzrDir.open_containing_from_transport(t)
 
758
        repo = control.open_repository()
 
759
        tmpf = tempfile.TemporaryFile()
 
760
        base_revision = revision.NULL_REVISION
 
761
        write_bundle(repo, revision_id, base_revision, tmpf)
 
762
        tmpf.seek(0)
 
763
        return SmartServerResponse((), tmpf.read())
 
764
 
 
765
    def dispatch_command(self, cmd, args):
 
766
        """Deprecated compatibility method.""" # XXX XXX
 
767
        func = getattr(self, 'do_' + cmd, None)
 
768
        if func is None:
 
769
            raise errors.SmartProtocolError("bad request %r" % (cmd,))
 
770
        self._run_handler_code(func, args, {})
 
771
 
 
772
    def _run_handler_code(self, callable, args, kwargs):
 
773
        """Run some handler specific code 'callable'.
 
774
 
 
775
        If a result is returned, it is considered to be the commands response,
 
776
        and finished_reading is set true, and its assigned to self.response.
 
777
 
 
778
        Any exceptions caught are translated and a response object created
 
779
        from them.
 
780
        """
 
781
        result = self._call_converting_errors(callable, args, kwargs)
 
782
        if result is not None:
 
783
            self.response = result
 
784
            self.finished_reading = True
 
785
        # handle unconverted commands
 
786
        if not self._converted_command:
 
787
            self.finished_reading = True
 
788
            if result is None:
 
789
                self.response = SmartServerResponse(('ok',))
 
790
 
 
791
    def _call_converting_errors(self, callable, args, kwargs):
 
792
        """Call callable converting errors to Response objects."""
 
793
        try:
 
794
            return callable(*args, **kwargs)
 
795
        except errors.NoSuchFile, e:
 
796
            return SmartServerResponse(('NoSuchFile', e.path))
 
797
        except errors.FileExists, e:
 
798
            return SmartServerResponse(('FileExists', e.path))
 
799
        except errors.DirectoryNotEmpty, e:
 
800
            return SmartServerResponse(('DirectoryNotEmpty', e.path))
 
801
        except errors.ShortReadvError, e:
 
802
            return SmartServerResponse(('ShortReadvError',
 
803
                e.path, str(e.offset), str(e.length), str(e.actual)))
 
804
        except UnicodeError, e:
 
805
            # If it is a DecodeError, than most likely we are starting
 
806
            # with a plain string
 
807
            str_or_unicode = e.object
 
808
            if isinstance(str_or_unicode, unicode):
 
809
                # XXX: UTF-8 might have \x01 (our seperator byte) in it.  We
 
810
                # should escape it somehow.
 
811
                val = 'u:' + str_or_unicode.encode('utf-8')
 
812
            else:
 
813
                val = 's:' + str_or_unicode.encode('base64')
 
814
            # This handles UnicodeEncodeError or UnicodeDecodeError
 
815
            return SmartServerResponse((e.__class__.__name__,
 
816
                    e.encoding, val, str(e.start), str(e.end), e.reason))
 
817
        except errors.TransportNotPossible, e:
 
818
            if e.msg == "readonly transport":
 
819
                return SmartServerResponse(('ReadOnlyError', ))
 
820
            else:
 
821
                raise
 
822
 
 
823
 
 
824
class SmartTCPServer(object):
 
825
    """Listens on a TCP socket and accepts connections from smart clients
 
826
 
 
827
    hooks: An instance of SmartServerHooks.
 
828
    """
 
829
 
 
830
    def __init__(self, backing_transport, host='127.0.0.1', port=0):
 
831
        """Construct a new server.
 
832
 
 
833
        To actually start it running, call either start_background_thread or
 
834
        serve.
 
835
 
 
836
        :param host: Name of the interface to listen on.
 
837
        :param port: TCP port to listen on, or 0 to allocate a transient port.
 
838
        """
 
839
        # let connections timeout so that we get a chance to terminate
 
840
        # Keep a reference to the exceptions we want to catch because the socket
 
841
        # module's globals get set to None during interpreter shutdown.
 
842
        from socket import timeout as socket_timeout
 
843
        from socket import error as socket_error
 
844
        self._socket_error = socket_error
 
845
        self._socket_timeout = socket_timeout
 
846
        self._server_socket = socket.socket()
 
847
        self._server_socket.bind((host, port))
 
848
        self._sockname = self._server_socket.getsockname()
 
849
        self.port = self._sockname[1]
 
850
        self._server_socket.listen(1)
 
851
        self._server_socket.settimeout(1)
 
852
        self.backing_transport = backing_transport
 
853
        self._started = threading.Event()
 
854
        self._stopped = threading.Event()
 
855
 
 
856
    def serve(self):
 
857
        self._should_terminate = False
 
858
        for hook in SmartTCPServer.hooks['server_started']:
 
859
            hook(self.backing_transport.base, self.get_url())
 
860
        self._started.set()
 
861
        try:
 
862
            try:
 
863
                while not self._should_terminate:
 
864
                    try:
 
865
                        conn, client_addr = self._server_socket.accept()
 
866
                    except self._socket_timeout:
 
867
                        # just check if we're asked to stop
 
868
                        pass
 
869
                    except self._socket_error, e:
 
870
                        # if the socket is closed by stop_background_thread
 
871
                        # we might get a EBADF here, any other socket errors
 
872
                        # should get logged.
 
873
                        if e.args[0] != errno.EBADF:
 
874
                            trace.warning("listening socket error: %s", e)
 
875
                    else:
 
876
                        self.serve_conn(conn)
 
877
            except KeyboardInterrupt:
 
878
                # dont log when CTRL-C'd.
 
879
                raise
 
880
            except Exception, e:
 
881
                trace.error("Unhandled smart server error.")
 
882
                trace.log_exception_quietly()
 
883
                raise
 
884
        finally:
 
885
            self._stopped.set()
 
886
            try:
 
887
                # ensure the server socket is closed.
 
888
                self._server_socket.close()
 
889
            except self._socket_error:
 
890
                # ignore errors on close
 
891
                pass
 
892
            for hook in SmartTCPServer.hooks['server_stopped']:
 
893
                hook(self.backing_transport.base, self.get_url())
 
894
 
 
895
    def get_url(self):
 
896
        """Return the url of the server"""
 
897
        return "bzr://%s:%d/" % self._sockname
 
898
 
 
899
    def serve_conn(self, conn):
 
900
        # For WIN32, where the timeout value from the listening socket
 
901
        # propogates to the newly accepted socket.
 
902
        conn.setblocking(True)
 
903
        conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
904
        handler = SmartServerSocketStreamMedium(conn, self.backing_transport)
 
905
        connection_thread = threading.Thread(None, handler.serve, name='smart-server-child')
 
906
        connection_thread.setDaemon(True)
 
907
        connection_thread.start()
 
908
 
 
909
    def start_background_thread(self):
 
910
        self._started.clear()
 
911
        self._server_thread = threading.Thread(None,
 
912
                self.serve,
 
913
                name='server-' + self.get_url())
 
914
        self._server_thread.setDaemon(True)
 
915
        self._server_thread.start()
 
916
        self._started.wait()
 
917
 
 
918
    def stop_background_thread(self):
 
919
        self._stopped.clear()
 
920
        # tell the main loop to quit on the next iteration.
 
921
        self._should_terminate = True
 
922
        # close the socket - gives error to connections from here on in,
 
923
        # rather than a connection reset error to connections made during
 
924
        # the period between setting _should_terminate = True and 
 
925
        # the current request completing/aborting. It may also break out the
 
926
        # main loop if it was currently in accept() (on some platforms).
 
927
        try:
 
928
            self._server_socket.close()
 
929
        except self._socket_error:
 
930
            # ignore errors on close
 
931
            pass
 
932
        if not self._stopped.isSet():
 
933
            # server has not stopped (though it may be stopping)
 
934
            # its likely in accept(), so give it a connection
 
935
            temp_socket = socket.socket()
 
936
            temp_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
937
            if not temp_socket.connect_ex(self._sockname):
 
938
                # and close it immediately: we dont choose to send any requests.
 
939
                temp_socket.close()
 
940
        self._stopped.wait()
 
941
        self._server_thread.join()
 
942
 
 
943
 
 
944
class SmartServerHooks(Hooks):
 
945
    """Hooks for the smart server."""
 
946
 
 
947
    def __init__(self):
 
948
        """Create the default hooks.
 
949
 
 
950
        These are all empty initially, because by default nothing should get
 
951
        notified.
 
952
        """
 
953
        Hooks.__init__(self)
 
954
        # Introduced in 0.16:
 
955
        # invoked whenever the server starts serving a directory.
 
956
        # The api signature is (backing url, public url).
 
957
        self['server_started'] = []
 
958
        # Introduced in 0.16:
 
959
        # invoked whenever the server stops serving a directory.
 
960
        # The api signature is (backing url, public url).
 
961
        self['server_stopped'] = []
 
962
 
 
963
SmartTCPServer.hooks = SmartServerHooks()
 
964
 
 
965
 
 
966
class SmartTCPServer_for_testing(SmartTCPServer):
 
967
    """Server suitable for use by transport tests.
 
968
    
 
969
    This server is backed by the process's cwd.
 
970
    """
 
971
 
 
972
    def __init__(self):
 
973
        self._homedir = urlutils.local_path_to_url(os.getcwd())[7:]
 
974
        # The server is set up by default like for ssh access: the client
 
975
        # passes filesystem-absolute paths; therefore the server must look
 
976
        # them up relative to the root directory.  it might be better to act
 
977
        # a public server and have the server rewrite paths into the test
 
978
        # directory.
 
979
        SmartTCPServer.__init__(self,
 
980
            transport.get_transport(urlutils.local_path_to_url('/')))
 
981
        
 
982
    def get_backing_transport(self, backing_transport_server):
 
983
        """Get a backing transport from a server we are decorating."""
 
984
        return transport.get_transport(backing_transport_server.get_url())
 
985
 
 
986
    def setUp(self, backing_transport_server=None):
 
987
        """Set up server for testing"""
 
988
        from bzrlib.transport.chroot import TestingChrootServer
 
989
        if backing_transport_server is None:
 
990
            from bzrlib.transport.local import LocalURLServer
 
991
            backing_transport_server = LocalURLServer()
 
992
        self.chroot_server = TestingChrootServer()
 
993
        self.chroot_server.setUp(backing_transport_server)
 
994
        self.backing_transport = transport.get_transport(
 
995
            self.chroot_server.get_url())
 
996
        self.start_background_thread()
 
997
 
 
998
    def tearDown(self):
 
999
        self.stop_background_thread()
 
1000
 
 
1001
    def get_bogus_url(self):
 
1002
        """Return a URL which will fail to connect"""
 
1003
        return 'bzr://127.0.0.1:1/'
 
1004
 
 
1005
 
 
1006
class SmartStat(object):
 
1007
 
 
1008
    def __init__(self, size, mode):
 
1009
        self.st_size = size
 
1010
        self.st_mode = mode
 
1011
 
 
1012
 
 
1013
class SmartTransport(transport.Transport):
 
1014
    """Connection to a smart server.
 
1015
 
 
1016
    The connection holds references to pipes that can be used to send requests
 
1017
    to the server.
 
1018
 
 
1019
    The connection has a notion of the current directory to which it's
 
1020
    connected; this is incorporated in filenames passed to the server.
 
1021
    
 
1022
    This supports some higher-level RPC operations and can also be treated 
 
1023
    like a Transport to do file-like operations.
 
1024
 
 
1025
    The connection can be made over a tcp socket, or (in future) an ssh pipe
 
1026
    or a series of http requests.  There are concrete subclasses for each
 
1027
    type: SmartTCPTransport, etc.
 
1028
    """
 
1029
 
 
1030
    # IMPORTANT FOR IMPLEMENTORS: SmartTransport MUST NOT be given encoding
 
1031
    # responsibilities: Put those on SmartClient or similar. This is vital for
 
1032
    # the ability to support multiple versions of the smart protocol over time:
 
1033
    # SmartTransport is an adapter from the Transport object model to the 
 
1034
    # SmartClient model, not an encoder.
 
1035
 
 
1036
    def __init__(self, url, clone_from=None, medium=None):
 
1037
        """Constructor.
 
1038
 
 
1039
        :param medium: The medium to use for this RemoteTransport. This must be
 
1040
            supplied if clone_from is None.
 
1041
        """
 
1042
        ### Technically super() here is faulty because Transport's __init__
 
1043
        ### fails to take 2 parameters, and if super were to choose a silly
 
1044
        ### initialisation order things would blow up. 
 
1045
        if not url.endswith('/'):
 
1046
            url += '/'
 
1047
        super(SmartTransport, self).__init__(url)
 
1048
        self._scheme, self._username, self._password, self._host, self._port, self._path = \
 
1049
                transport.split_url(url)
 
1050
        if clone_from is None:
 
1051
            self._medium = medium
 
1052
        else:
 
1053
            # credentials may be stripped from the base in some circumstances
 
1054
            # as yet to be clearly defined or documented, so copy them.
 
1055
            self._username = clone_from._username
 
1056
            # reuse same connection
 
1057
            self._medium = clone_from._medium
 
1058
        assert self._medium is not None
 
1059
 
 
1060
    def abspath(self, relpath):
 
1061
        """Return the full url to the given relative path.
 
1062
        
 
1063
        @param relpath: the relative path or path components
 
1064
        @type relpath: str or list
 
1065
        """
 
1066
        return self._unparse_url(self._remote_path(relpath))
 
1067
    
 
1068
    def clone(self, relative_url):
 
1069
        """Make a new SmartTransport related to me, sharing the same connection.
 
1070
 
 
1071
        This essentially opens a handle on a different remote directory.
 
1072
        """
 
1073
        if relative_url is None:
 
1074
            return SmartTransport(self.base, self)
 
1075
        else:
 
1076
            return SmartTransport(self.abspath(relative_url), self)
 
1077
 
 
1078
    def is_readonly(self):
 
1079
        """Smart server transport can do read/write file operations."""
 
1080
        return False
 
1081
                                                   
 
1082
    def get_smart_client(self):
 
1083
        return self._medium
 
1084
 
 
1085
    def get_smart_medium(self):
 
1086
        return self._medium
 
1087
                                                   
 
1088
    def _unparse_url(self, path):
 
1089
        """Return URL for a path.
 
1090
 
 
1091
        :see: SFTPUrlHandling._unparse_url
 
1092
        """
 
1093
        # TODO: Eventually it should be possible to unify this with
 
1094
        # SFTPUrlHandling._unparse_url?
 
1095
        if path == '':
 
1096
            path = '/'
 
1097
        path = urllib.quote(path)
 
1098
        netloc = urllib.quote(self._host)
 
1099
        if self._username is not None:
 
1100
            netloc = '%s@%s' % (urllib.quote(self._username), netloc)
 
1101
        if self._port is not None:
 
1102
            netloc = '%s:%d' % (netloc, self._port)
 
1103
        return urlparse.urlunparse((self._scheme, netloc, path, '', '', ''))
 
1104
 
 
1105
    def _remote_path(self, relpath):
 
1106
        """Returns the Unicode version of the absolute path for relpath."""
 
1107
        return self._combine_paths(self._path, relpath)
 
1108
 
 
1109
    def _call(self, method, *args):
 
1110
        resp = self._call2(method, *args)
 
1111
        self._translate_error(resp)
 
1112
 
 
1113
    def _call2(self, method, *args):
 
1114
        """Call a method on the remote server."""
 
1115
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1116
        protocol.call(method, *args)
 
1117
        return protocol.read_response_tuple()
 
1118
 
 
1119
    def _call_with_body_bytes(self, method, args, body):
 
1120
        """Call a method on the remote server with body bytes."""
 
1121
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1122
        protocol.call_with_body_bytes((method, ) + args, body)
 
1123
        return protocol.read_response_tuple()
 
1124
 
 
1125
    def has(self, relpath):
 
1126
        """Indicate whether a remote file of the given name exists or not.
 
1127
 
 
1128
        :see: Transport.has()
 
1129
        """
 
1130
        resp = self._call2('has', self._remote_path(relpath))
 
1131
        if resp == ('yes', ):
 
1132
            return True
 
1133
        elif resp == ('no', ):
 
1134
            return False
 
1135
        else:
 
1136
            self._translate_error(resp)
 
1137
 
 
1138
    def get(self, relpath):
 
1139
        """Return file-like object reading the contents of a remote file.
 
1140
        
 
1141
        :see: Transport.get_bytes()/get_file()
 
1142
        """
 
1143
        return StringIO(self.get_bytes(relpath))
 
1144
 
 
1145
    def get_bytes(self, relpath):
 
1146
        remote = self._remote_path(relpath)
 
1147
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1148
        protocol.call('get', remote)
 
1149
        resp = protocol.read_response_tuple(True)
 
1150
        if resp != ('ok', ):
 
1151
            protocol.cancel_read_body()
 
1152
            self._translate_error(resp, relpath)
 
1153
        return protocol.read_body_bytes()
 
1154
 
 
1155
    def _serialise_optional_mode(self, mode):
 
1156
        if mode is None:
 
1157
            return ''
 
1158
        else:
 
1159
            return '%d' % mode
 
1160
 
 
1161
    def mkdir(self, relpath, mode=None):
 
1162
        resp = self._call2('mkdir', self._remote_path(relpath),
 
1163
            self._serialise_optional_mode(mode))
 
1164
        self._translate_error(resp)
 
1165
 
 
1166
    def put_bytes(self, relpath, upload_contents, mode=None):
 
1167
        # FIXME: upload_file is probably not safe for non-ascii characters -
 
1168
        # should probably just pass all parameters as length-delimited
 
1169
        # strings?
 
1170
        resp = self._call_with_body_bytes('put',
 
1171
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
 
1172
            upload_contents)
 
1173
        self._translate_error(resp)
 
1174
 
 
1175
    def put_bytes_non_atomic(self, relpath, bytes, mode=None,
 
1176
                             create_parent_dir=False,
 
1177
                             dir_mode=None):
 
1178
        """See Transport.put_bytes_non_atomic."""
 
1179
        # FIXME: no encoding in the transport!
 
1180
        create_parent_str = 'F'
 
1181
        if create_parent_dir:
 
1182
            create_parent_str = 'T'
 
1183
 
 
1184
        resp = self._call_with_body_bytes(
 
1185
            'put_non_atomic',
 
1186
            (self._remote_path(relpath), self._serialise_optional_mode(mode),
 
1187
             create_parent_str, self._serialise_optional_mode(dir_mode)),
 
1188
            bytes)
 
1189
        self._translate_error(resp)
 
1190
 
 
1191
    def put_file(self, relpath, upload_file, mode=None):
 
1192
        # its not ideal to seek back, but currently put_non_atomic_file depends
 
1193
        # on transports not reading before failing - which is a faulty
 
1194
        # assumption I think - RBC 20060915
 
1195
        pos = upload_file.tell()
 
1196
        try:
 
1197
            return self.put_bytes(relpath, upload_file.read(), mode)
 
1198
        except:
 
1199
            upload_file.seek(pos)
 
1200
            raise
 
1201
 
 
1202
    def put_file_non_atomic(self, relpath, f, mode=None,
 
1203
                            create_parent_dir=False,
 
1204
                            dir_mode=None):
 
1205
        return self.put_bytes_non_atomic(relpath, f.read(), mode=mode,
 
1206
                                         create_parent_dir=create_parent_dir,
 
1207
                                         dir_mode=dir_mode)
 
1208
 
 
1209
    def append_file(self, relpath, from_file, mode=None):
 
1210
        return self.append_bytes(relpath, from_file.read(), mode)
 
1211
        
 
1212
    def append_bytes(self, relpath, bytes, mode=None):
 
1213
        resp = self._call_with_body_bytes(
 
1214
            'append',
 
1215
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
 
1216
            bytes)
 
1217
        if resp[0] == 'appended':
 
1218
            return int(resp[1])
 
1219
        self._translate_error(resp)
 
1220
 
 
1221
    def delete(self, relpath):
 
1222
        resp = self._call2('delete', self._remote_path(relpath))
 
1223
        self._translate_error(resp)
 
1224
 
 
1225
    def readv(self, relpath, offsets):
 
1226
        if not offsets:
 
1227
            return
 
1228
 
 
1229
        offsets = list(offsets)
 
1230
 
 
1231
        sorted_offsets = sorted(offsets)
 
1232
        # turn the list of offsets into a stack
 
1233
        offset_stack = iter(offsets)
 
1234
        cur_offset_and_size = offset_stack.next()
 
1235
        coalesced = list(self._coalesce_offsets(sorted_offsets,
 
1236
                               limit=self._max_readv_combine,
 
1237
                               fudge_factor=self._bytes_to_read_before_seek))
 
1238
 
 
1239
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1240
        protocol.call_with_body_readv_array(
 
1241
            ('readv', self._remote_path(relpath)),
 
1242
            [(c.start, c.length) for c in coalesced])
 
1243
        resp = protocol.read_response_tuple(True)
 
1244
 
 
1245
        if resp[0] != 'readv':
 
1246
            # This should raise an exception
 
1247
            protocol.cancel_read_body()
 
1248
            self._translate_error(resp)
 
1249
            return
 
1250
 
 
1251
        # FIXME: this should know how many bytes are needed, for clarity.
 
1252
        data = protocol.read_body_bytes()
 
1253
        # Cache the results, but only until they have been fulfilled
 
1254
        data_map = {}
 
1255
        for c_offset in coalesced:
 
1256
            if len(data) < c_offset.length:
 
1257
                raise errors.ShortReadvError(relpath, c_offset.start,
 
1258
                            c_offset.length, actual=len(data))
 
1259
            for suboffset, subsize in c_offset.ranges:
 
1260
                key = (c_offset.start+suboffset, subsize)
 
1261
                data_map[key] = data[suboffset:suboffset+subsize]
 
1262
            data = data[c_offset.length:]
 
1263
 
 
1264
            # Now that we've read some data, see if we can yield anything back
 
1265
            while cur_offset_and_size in data_map:
 
1266
                this_data = data_map.pop(cur_offset_and_size)
 
1267
                yield cur_offset_and_size[0], this_data
 
1268
                cur_offset_and_size = offset_stack.next()
 
1269
 
 
1270
    def rename(self, rel_from, rel_to):
 
1271
        self._call('rename',
 
1272
                   self._remote_path(rel_from),
 
1273
                   self._remote_path(rel_to))
 
1274
 
 
1275
    def move(self, rel_from, rel_to):
 
1276
        self._call('move',
 
1277
                   self._remote_path(rel_from),
 
1278
                   self._remote_path(rel_to))
 
1279
 
 
1280
    def rmdir(self, relpath):
 
1281
        resp = self._call('rmdir', self._remote_path(relpath))
 
1282
 
 
1283
    def _translate_error(self, resp, orig_path=None):
 
1284
        """Raise an exception from a response"""
 
1285
        if resp is None:
 
1286
            what = None
 
1287
        else:
 
1288
            what = resp[0]
 
1289
        if what == 'ok':
 
1290
            return
 
1291
        elif what == 'NoSuchFile':
 
1292
            if orig_path is not None:
 
1293
                error_path = orig_path
 
1294
            else:
 
1295
                error_path = resp[1]
 
1296
            raise errors.NoSuchFile(error_path)
 
1297
        elif what == 'error':
 
1298
            raise errors.SmartProtocolError(unicode(resp[1]))
 
1299
        elif what == 'FileExists':
 
1300
            raise errors.FileExists(resp[1])
 
1301
        elif what == 'DirectoryNotEmpty':
 
1302
            raise errors.DirectoryNotEmpty(resp[1])
 
1303
        elif what == 'ShortReadvError':
 
1304
            raise errors.ShortReadvError(resp[1], int(resp[2]),
 
1305
                                         int(resp[3]), int(resp[4]))
 
1306
        elif what in ('UnicodeEncodeError', 'UnicodeDecodeError'):
 
1307
            encoding = str(resp[1]) # encoding must always be a string
 
1308
            val = resp[2]
 
1309
            start = int(resp[3])
 
1310
            end = int(resp[4])
 
1311
            reason = str(resp[5]) # reason must always be a string
 
1312
            if val.startswith('u:'):
 
1313
                val = val[2:].decode('utf-8')
 
1314
            elif val.startswith('s:'):
 
1315
                val = val[2:].decode('base64')
 
1316
            if what == 'UnicodeDecodeError':
 
1317
                raise UnicodeDecodeError(encoding, val, start, end, reason)
 
1318
            elif what == 'UnicodeEncodeError':
 
1319
                raise UnicodeEncodeError(encoding, val, start, end, reason)
 
1320
        elif what == "ReadOnlyError":
 
1321
            raise errors.TransportNotPossible('readonly transport')
 
1322
        else:
 
1323
            raise errors.SmartProtocolError('unexpected smart server error: %r' % (resp,))
 
1324
 
 
1325
    def disconnect(self):
 
1326
        self._medium.disconnect()
 
1327
 
 
1328
    def delete_tree(self, relpath):
 
1329
        raise errors.TransportNotPossible('readonly transport')
 
1330
 
 
1331
    def stat(self, relpath):
 
1332
        resp = self._call2('stat', self._remote_path(relpath))
 
1333
        if resp[0] == 'stat':
 
1334
            return SmartStat(int(resp[1]), int(resp[2], 8))
 
1335
        else:
 
1336
            self._translate_error(resp)
 
1337
 
 
1338
    ## def lock_read(self, relpath):
 
1339
    ##     """Lock the given file for shared (read) access.
 
1340
    ##     :return: A lock object, which should be passed to Transport.unlock()
 
1341
    ##     """
 
1342
    ##     # The old RemoteBranch ignore lock for reading, so we will
 
1343
    ##     # continue that tradition and return a bogus lock object.
 
1344
    ##     class BogusLock(object):
 
1345
    ##         def __init__(self, path):
 
1346
    ##             self.path = path
 
1347
    ##         def unlock(self):
 
1348
    ##             pass
 
1349
    ##     return BogusLock(relpath)
 
1350
 
 
1351
    def listable(self):
 
1352
        return True
 
1353
 
 
1354
    def list_dir(self, relpath):
 
1355
        resp = self._call2('list_dir', self._remote_path(relpath))
 
1356
        if resp[0] == 'names':
 
1357
            return [name.encode('ascii') for name in resp[1:]]
 
1358
        else:
 
1359
            self._translate_error(resp)
 
1360
 
 
1361
    def iter_files_recursive(self):
 
1362
        resp = self._call2('iter_files_recursive', self._remote_path(''))
 
1363
        if resp[0] == 'names':
 
1364
            return resp[1:]
 
1365
        else:
 
1366
            self._translate_error(resp)
 
1367
 
 
1368
 
 
1369
class SmartClientMediumRequest(object):
 
1370
    """A request on a SmartClientMedium.
 
1371
 
 
1372
    Each request allows bytes to be provided to it via accept_bytes, and then
 
1373
    the response bytes to be read via read_bytes.
 
1374
 
 
1375
    For instance:
 
1376
    request.accept_bytes('123')
 
1377
    request.finished_writing()
 
1378
    result = request.read_bytes(3)
 
1379
    request.finished_reading()
 
1380
 
 
1381
    It is up to the individual SmartClientMedium whether multiple concurrent
 
1382
    requests can exist. See SmartClientMedium.get_request to obtain instances 
 
1383
    of SmartClientMediumRequest, and the concrete Medium you are using for 
 
1384
    details on concurrency and pipelining.
 
1385
    """
 
1386
 
 
1387
    def __init__(self, medium):
 
1388
        """Construct a SmartClientMediumRequest for the medium medium."""
 
1389
        self._medium = medium
 
1390
        # we track state by constants - we may want to use the same
 
1391
        # pattern as BodyReader if it gets more complex.
 
1392
        # valid states are: "writing", "reading", "done"
 
1393
        self._state = "writing"
 
1394
 
 
1395
    def accept_bytes(self, bytes):
 
1396
        """Accept bytes for inclusion in this request.
 
1397
 
 
1398
        This method may not be be called after finished_writing() has been
 
1399
        called.  It depends upon the Medium whether or not the bytes will be
 
1400
        immediately transmitted. Message based Mediums will tend to buffer the
 
1401
        bytes until finished_writing() is called.
 
1402
 
 
1403
        :param bytes: A bytestring.
 
1404
        """
 
1405
        if self._state != "writing":
 
1406
            raise errors.WritingCompleted(self)
 
1407
        self._accept_bytes(bytes)
 
1408
 
 
1409
    def _accept_bytes(self, bytes):
 
1410
        """Helper for accept_bytes.
 
1411
 
 
1412
        Accept_bytes checks the state of the request to determing if bytes
 
1413
        should be accepted. After that it hands off to _accept_bytes to do the
 
1414
        actual acceptance.
 
1415
        """
 
1416
        raise NotImplementedError(self._accept_bytes)
 
1417
 
 
1418
    def finished_reading(self):
 
1419
        """Inform the request that all desired data has been read.
 
1420
 
 
1421
        This will remove the request from the pipeline for its medium (if the
 
1422
        medium supports pipelining) and any further calls to methods on the
 
1423
        request will raise ReadingCompleted.
 
1424
        """
 
1425
        if self._state == "writing":
 
1426
            raise errors.WritingNotComplete(self)
 
1427
        if self._state != "reading":
 
1428
            raise errors.ReadingCompleted(self)
 
1429
        self._state = "done"
 
1430
        self._finished_reading()
 
1431
 
 
1432
    def _finished_reading(self):
 
1433
        """Helper for finished_reading.
 
1434
 
 
1435
        finished_reading checks the state of the request to determine if 
 
1436
        finished_reading is allowed, and if it is hands off to _finished_reading
 
1437
        to perform the action.
 
1438
        """
 
1439
        raise NotImplementedError(self._finished_reading)
 
1440
 
 
1441
    def finished_writing(self):
 
1442
        """Finish the writing phase of this request.
 
1443
 
 
1444
        This will flush all pending data for this request along the medium.
 
1445
        After calling finished_writing, you may not call accept_bytes anymore.
 
1446
        """
 
1447
        if self._state != "writing":
 
1448
            raise errors.WritingCompleted(self)
 
1449
        self._state = "reading"
 
1450
        self._finished_writing()
 
1451
 
 
1452
    def _finished_writing(self):
 
1453
        """Helper for finished_writing.
 
1454
 
 
1455
        finished_writing checks the state of the request to determine if 
 
1456
        finished_writing is allowed, and if it is hands off to _finished_writing
 
1457
        to perform the action.
 
1458
        """
 
1459
        raise NotImplementedError(self._finished_writing)
 
1460
 
 
1461
    def read_bytes(self, count):
 
1462
        """Read bytes from this requests response.
 
1463
 
 
1464
        This method will block and wait for count bytes to be read. It may not
 
1465
        be invoked until finished_writing() has been called - this is to ensure
 
1466
        a message-based approach to requests, for compatability with message
 
1467
        based mediums like HTTP.
 
1468
        """
 
1469
        if self._state == "writing":
 
1470
            raise errors.WritingNotComplete(self)
 
1471
        if self._state != "reading":
 
1472
            raise errors.ReadingCompleted(self)
 
1473
        return self._read_bytes(count)
 
1474
 
 
1475
    def _read_bytes(self, count):
 
1476
        """Helper for read_bytes.
 
1477
 
 
1478
        read_bytes checks the state of the request to determing if bytes
 
1479
        should be read. After that it hands off to _read_bytes to do the
 
1480
        actual read.
 
1481
        """
 
1482
        raise NotImplementedError(self._read_bytes)
 
1483
 
 
1484
 
 
1485
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
 
1486
    """A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
 
1487
 
 
1488
    def __init__(self, medium):
 
1489
        SmartClientMediumRequest.__init__(self, medium)
 
1490
        # check that we are safe concurrency wise. If some streams start
 
1491
        # allowing concurrent requests - i.e. via multiplexing - then this
 
1492
        # assert should be moved to SmartClientStreamMedium.get_request,
 
1493
        # and the setting/unsetting of _current_request likewise moved into
 
1494
        # that class : but its unneeded overhead for now. RBC 20060922
 
1495
        if self._medium._current_request is not None:
 
1496
            raise errors.TooManyConcurrentRequests(self._medium)
 
1497
        self._medium._current_request = self
 
1498
 
 
1499
    def _accept_bytes(self, bytes):
 
1500
        """See SmartClientMediumRequest._accept_bytes.
 
1501
        
 
1502
        This forwards to self._medium._accept_bytes because we are operating
 
1503
        on the mediums stream.
 
1504
        """
 
1505
        self._medium._accept_bytes(bytes)
 
1506
 
 
1507
    def _finished_reading(self):
 
1508
        """See SmartClientMediumRequest._finished_reading.
 
1509
 
 
1510
        This clears the _current_request on self._medium to allow a new 
 
1511
        request to be created.
 
1512
        """
 
1513
        assert self._medium._current_request is self
 
1514
        self._medium._current_request = None
 
1515
        
 
1516
    def _finished_writing(self):
 
1517
        """See SmartClientMediumRequest._finished_writing.
 
1518
 
 
1519
        This invokes self._medium._flush to ensure all bytes are transmitted.
 
1520
        """
 
1521
        self._medium._flush()
 
1522
 
 
1523
    def _read_bytes(self, count):
 
1524
        """See SmartClientMediumRequest._read_bytes.
 
1525
        
 
1526
        This forwards to self._medium._read_bytes because we are operating
 
1527
        on the mediums stream.
 
1528
        """
 
1529
        return self._medium._read_bytes(count)
 
1530
 
 
1531
 
 
1532
class SmartClientRequestProtocolOne(SmartProtocolBase):
 
1533
    """The client-side protocol for smart version 1."""
 
1534
 
 
1535
    def __init__(self, request):
 
1536
        """Construct a SmartClientRequestProtocolOne.
 
1537
 
 
1538
        :param request: A SmartClientMediumRequest to serialise onto and
 
1539
            deserialise from.
 
1540
        """
 
1541
        self._request = request
 
1542
        self._body_buffer = None
 
1543
 
 
1544
    def call(self, *args):
 
1545
        bytes = _encode_tuple(args)
 
1546
        self._request.accept_bytes(bytes)
 
1547
        self._request.finished_writing()
 
1548
 
 
1549
    def call_with_body_bytes(self, args, body):
 
1550
        """Make a remote call of args with body bytes 'body'.
 
1551
 
 
1552
        After calling this, call read_response_tuple to find the result out.
 
1553
        """
 
1554
        bytes = _encode_tuple(args)
 
1555
        self._request.accept_bytes(bytes)
 
1556
        bytes = self._encode_bulk_data(body)
 
1557
        self._request.accept_bytes(bytes)
 
1558
        self._request.finished_writing()
 
1559
 
 
1560
    def call_with_body_readv_array(self, args, body):
 
1561
        """Make a remote call with a readv array.
 
1562
 
 
1563
        The body is encoded with one line per readv offset pair. The numbers in
 
1564
        each pair are separated by a comma, and no trailing \n is emitted.
 
1565
        """
 
1566
        bytes = _encode_tuple(args)
 
1567
        self._request.accept_bytes(bytes)
 
1568
        readv_bytes = self._serialise_offsets(body)
 
1569
        bytes = self._encode_bulk_data(readv_bytes)
 
1570
        self._request.accept_bytes(bytes)
 
1571
        self._request.finished_writing()
 
1572
 
 
1573
    def cancel_read_body(self):
 
1574
        """After expecting a body, a response code may indicate one otherwise.
 
1575
 
 
1576
        This method lets the domain client inform the protocol that no body
 
1577
        will be transmitted. This is a terminal method: after calling it the
 
1578
        protocol is not able to be used further.
 
1579
        """
 
1580
        self._request.finished_reading()
 
1581
 
 
1582
    def read_response_tuple(self, expect_body=False):
 
1583
        """Read a response tuple from the wire.
 
1584
 
 
1585
        This should only be called once.
 
1586
        """
 
1587
        result = self._recv_tuple()
 
1588
        if not expect_body:
 
1589
            self._request.finished_reading()
 
1590
        return result
 
1591
 
 
1592
    def read_body_bytes(self, count=-1):
 
1593
        """Read bytes from the body, decoding into a byte stream.
 
1594
        
 
1595
        We read all bytes at once to ensure we've checked the trailer for 
 
1596
        errors, and then feed the buffer back as read_body_bytes is called.
 
1597
        """
 
1598
        if self._body_buffer is not None:
 
1599
            return self._body_buffer.read(count)
 
1600
        _body_decoder = LengthPrefixedBodyDecoder()
 
1601
 
 
1602
        while not _body_decoder.finished_reading:
 
1603
            bytes_wanted = _body_decoder.next_read_size()
 
1604
            bytes = self._request.read_bytes(bytes_wanted)
 
1605
            _body_decoder.accept_bytes(bytes)
 
1606
        self._request.finished_reading()
 
1607
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
 
1608
        # XXX: TODO check the trailer result.
 
1609
        return self._body_buffer.read(count)
 
1610
 
 
1611
    def _recv_tuple(self):
 
1612
        """Receive a tuple from the medium request."""
 
1613
        line = ''
 
1614
        while not line or line[-1] != '\n':
 
1615
            # TODO: this is inefficient - but tuples are short.
 
1616
            new_char = self._request.read_bytes(1)
 
1617
            line += new_char
 
1618
            assert new_char != '', "end of file reading from server."
 
1619
        return _decode_tuple(line)
 
1620
 
 
1621
    def query_version(self):
 
1622
        """Return protocol version number of the server."""
 
1623
        self.call('hello')
 
1624
        resp = self.read_response_tuple()
 
1625
        if resp == ('ok', '1'):
 
1626
            return 1
 
1627
        else:
 
1628
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
1629
 
 
1630
 
 
1631
class SmartClientMedium(object):
 
1632
    """Smart client is a medium for sending smart protocol requests over."""
 
1633
 
 
1634
    def disconnect(self):
 
1635
        """If this medium maintains a persistent connection, close it.
 
1636
        
 
1637
        The default implementation does nothing.
 
1638
        """
 
1639
        
 
1640
 
 
1641
class SmartClientStreamMedium(SmartClientMedium):
 
1642
    """Stream based medium common class.
 
1643
 
 
1644
    SmartClientStreamMediums operate on a stream. All subclasses use a common
 
1645
    SmartClientStreamMediumRequest for their requests, and should implement
 
1646
    _accept_bytes and _read_bytes to allow the request objects to send and
 
1647
    receive bytes.
 
1648
    """
 
1649
 
 
1650
    def __init__(self):
 
1651
        self._current_request = None
 
1652
 
 
1653
    def accept_bytes(self, bytes):
 
1654
        self._accept_bytes(bytes)
 
1655
 
 
1656
    def __del__(self):
 
1657
        """The SmartClientStreamMedium knows how to close the stream when it is
 
1658
        finished with it.
 
1659
        """
 
1660
        self.disconnect()
 
1661
 
 
1662
    def _flush(self):
 
1663
        """Flush the output stream.
 
1664
        
 
1665
        This method is used by the SmartClientStreamMediumRequest to ensure that
 
1666
        all data for a request is sent, to avoid long timeouts or deadlocks.
 
1667
        """
 
1668
        raise NotImplementedError(self._flush)
 
1669
 
 
1670
    def get_request(self):
 
1671
        """See SmartClientMedium.get_request().
 
1672
 
 
1673
        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
 
1674
        for get_request.
 
1675
        """
 
1676
        return SmartClientStreamMediumRequest(self)
 
1677
 
 
1678
    def read_bytes(self, count):
 
1679
        return self._read_bytes(count)
 
1680
 
 
1681
 
 
1682
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
 
1683
    """A client medium using simple pipes.
 
1684
    
 
1685
    This client does not manage the pipes: it assumes they will always be open.
 
1686
    """
 
1687
 
 
1688
    def __init__(self, readable_pipe, writeable_pipe):
 
1689
        SmartClientStreamMedium.__init__(self)
 
1690
        self._readable_pipe = readable_pipe
 
1691
        self._writeable_pipe = writeable_pipe
 
1692
 
 
1693
    def _accept_bytes(self, bytes):
 
1694
        """See SmartClientStreamMedium.accept_bytes."""
 
1695
        self._writeable_pipe.write(bytes)
 
1696
 
 
1697
    def _flush(self):
 
1698
        """See SmartClientStreamMedium._flush()."""
 
1699
        self._writeable_pipe.flush()
 
1700
 
 
1701
    def _read_bytes(self, count):
 
1702
        """See SmartClientStreamMedium._read_bytes."""
 
1703
        return self._readable_pipe.read(count)
 
1704
 
 
1705
 
 
1706
class SmartSSHClientMedium(SmartClientStreamMedium):
 
1707
    """A client medium using SSH."""
 
1708
    
 
1709
    def __init__(self, host, port=None, username=None, password=None,
 
1710
            vendor=None):
 
1711
        """Creates a client that will connect on the first use.
 
1712
        
 
1713
        :param vendor: An optional override for the ssh vendor to use. See
 
1714
            bzrlib.transport.ssh for details on ssh vendors.
 
1715
        """
 
1716
        SmartClientStreamMedium.__init__(self)
 
1717
        self._connected = False
 
1718
        self._host = host
 
1719
        self._password = password
 
1720
        self._port = port
 
1721
        self._username = username
 
1722
        self._read_from = None
 
1723
        self._ssh_connection = None
 
1724
        self._vendor = vendor
 
1725
        self._write_to = None
 
1726
 
 
1727
    def _accept_bytes(self, bytes):
 
1728
        """See SmartClientStreamMedium.accept_bytes."""
 
1729
        self._ensure_connection()
 
1730
        self._write_to.write(bytes)
 
1731
 
 
1732
    def disconnect(self):
 
1733
        """See SmartClientMedium.disconnect()."""
 
1734
        if not self._connected:
 
1735
            return
 
1736
        self._read_from.close()
 
1737
        self._write_to.close()
 
1738
        self._ssh_connection.close()
 
1739
        self._connected = False
 
1740
 
 
1741
    def _ensure_connection(self):
 
1742
        """Connect this medium if not already connected."""
 
1743
        if self._connected:
 
1744
            return
 
1745
        executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
 
1746
        if self._vendor is None:
 
1747
            vendor = ssh._get_ssh_vendor()
 
1748
        else:
 
1749
            vendor = self._vendor
 
1750
        self._ssh_connection = vendor.connect_ssh(self._username,
 
1751
                self._password, self._host, self._port,
 
1752
                command=[executable, 'serve', '--inet', '--directory=/',
 
1753
                         '--allow-writes'])
 
1754
        self._read_from, self._write_to = \
 
1755
            self._ssh_connection.get_filelike_channels()
 
1756
        self._connected = True
 
1757
 
 
1758
    def _flush(self):
 
1759
        """See SmartClientStreamMedium._flush()."""
 
1760
        self._write_to.flush()
 
1761
 
 
1762
    def _read_bytes(self, count):
 
1763
        """See SmartClientStreamMedium.read_bytes."""
 
1764
        if not self._connected:
 
1765
            raise errors.MediumNotConnected(self)
 
1766
        return self._read_from.read(count)
 
1767
 
 
1768
 
 
1769
class SmartTCPClientMedium(SmartClientStreamMedium):
 
1770
    """A client medium using TCP."""
 
1771
    
 
1772
    def __init__(self, host, port):
 
1773
        """Creates a client that will connect on the first use."""
 
1774
        SmartClientStreamMedium.__init__(self)
 
1775
        self._connected = False
 
1776
        self._host = host
 
1777
        self._port = port
 
1778
        self._socket = None
 
1779
 
 
1780
    def _accept_bytes(self, bytes):
 
1781
        """See SmartClientMedium.accept_bytes."""
 
1782
        self._ensure_connection()
 
1783
        self._socket.sendall(bytes)
 
1784
 
 
1785
    def disconnect(self):
 
1786
        """See SmartClientMedium.disconnect()."""
 
1787
        if not self._connected:
 
1788
            return
 
1789
        self._socket.close()
 
1790
        self._socket = None
 
1791
        self._connected = False
 
1792
 
 
1793
    def _ensure_connection(self):
 
1794
        """Connect this medium if not already connected."""
 
1795
        if self._connected:
 
1796
            return
 
1797
        self._socket = socket.socket()
 
1798
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
1799
        result = self._socket.connect_ex((self._host, int(self._port)))
 
1800
        if result:
 
1801
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
 
1802
                    (self._host, self._port, os.strerror(result)))
 
1803
        self._connected = True
 
1804
 
 
1805
    def _flush(self):
 
1806
        """See SmartClientStreamMedium._flush().
 
1807
        
 
1808
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
 
1809
        add a means to do a flush, but that can be done in the future.
 
1810
        """
 
1811
 
 
1812
    def _read_bytes(self, count):
 
1813
        """See SmartClientMedium.read_bytes."""
 
1814
        if not self._connected:
 
1815
            raise errors.MediumNotConnected(self)
 
1816
        return self._socket.recv(count)
 
1817
 
 
1818
 
 
1819
class SmartTCPTransport(SmartTransport):
 
1820
    """Connection to smart server over plain tcp.
 
1821
    
 
1822
    This is essentially just a factory to get 'RemoteTransport(url,
 
1823
        SmartTCPClientMedium).
 
1824
    """
 
1825
 
 
1826
    def __init__(self, url):
 
1827
        _scheme, _username, _password, _host, _port, _path = \
 
1828
            transport.split_url(url)
 
1829
        if _port is None:
 
1830
            _port = BZR_DEFAULT_PORT
 
1831
        else:
 
1832
            try:
 
1833
                _port = int(_port)
 
1834
            except (ValueError, TypeError), e:
 
1835
                raise errors.InvalidURL(
 
1836
                    path=url, extra="invalid port %s" % _port)
 
1837
        medium = SmartTCPClientMedium(_host, _port)
 
1838
        super(SmartTCPTransport, self).__init__(url, medium=medium)
 
1839
 
 
1840
 
 
1841
class SmartSSHTransport(SmartTransport):
 
1842
    """Connection to smart server over SSH.
 
1843
 
 
1844
    This is essentially just a factory to get 'RemoteTransport(url,
 
1845
        SmartSSHClientMedium).
 
1846
    """
 
1847
 
 
1848
    def __init__(self, url):
 
1849
        _scheme, _username, _password, _host, _port, _path = \
 
1850
            transport.split_url(url)
 
1851
        try:
 
1852
            if _port is not None:
 
1853
                _port = int(_port)
 
1854
        except (ValueError, TypeError), e:
 
1855
            raise errors.InvalidURL(path=url, extra="invalid port %s" % 
 
1856
                _port)
 
1857
        medium = SmartSSHClientMedium(_host, _port, _username, _password)
 
1858
        super(SmartSSHTransport, self).__init__(url, medium=medium)
 
1859
 
 
1860
 
 
1861
class SmartHTTPTransport(SmartTransport):
 
1862
    """Just a way to connect between a bzr+http:// url and http://.
 
1863
    
 
1864
    This connection operates slightly differently than the SmartSSHTransport.
 
1865
    It uses a plain http:// transport underneath, which defines what remote
 
1866
    .bzr/smart URL we are connected to. From there, all paths that are sent are
 
1867
    sent as relative paths, this way, the remote side can properly
 
1868
    de-reference them, since it is likely doing rewrite rules to translate an
 
1869
    HTTP path into a local path.
 
1870
    """
 
1871
 
 
1872
    def __init__(self, url, http_transport=None):
 
1873
        assert url.startswith('bzr+http://')
 
1874
 
 
1875
        if http_transport is None:
 
1876
            http_url = url[len('bzr+'):]
 
1877
            self._http_transport = transport.get_transport(http_url)
 
1878
        else:
 
1879
            self._http_transport = http_transport
 
1880
        http_medium = self._http_transport.get_smart_medium()
 
1881
        super(SmartHTTPTransport, self).__init__(url, medium=http_medium)
 
1882
 
 
1883
    def _remote_path(self, relpath):
 
1884
        """After connecting HTTP Transport only deals in relative URLs."""
 
1885
        if relpath == '.':
 
1886
            return ''
 
1887
        else:
 
1888
            return relpath
 
1889
 
 
1890
    def abspath(self, relpath):
 
1891
        """Return the full url to the given relative path.
 
1892
        
 
1893
        :param relpath: the relative path or path components
 
1894
        :type relpath: str or list
 
1895
        """
 
1896
        return self._unparse_url(self._combine_paths(self._path, relpath))
 
1897
 
 
1898
    def clone(self, relative_url):
 
1899
        """Make a new SmartHTTPTransport related to me.
 
1900
 
 
1901
        This is re-implemented rather than using the default
 
1902
        SmartTransport.clone() because we must be careful about the underlying
 
1903
        http transport.
 
1904
        """
 
1905
        if relative_url:
 
1906
            abs_url = self.abspath(relative_url)
 
1907
        else:
 
1908
            abs_url = self.base
 
1909
        # By cloning the underlying http_transport, we are able to share the
 
1910
        # connection.
 
1911
        new_transport = self._http_transport.clone(relative_url)
 
1912
        return SmartHTTPTransport(abs_url, http_transport=new_transport)
 
1913
 
 
1914
 
 
1915
def get_test_permutations():
 
1916
    """Return (transport, server) permutations for testing."""
 
1917
    ### We may need a little more test framework support to construct an
 
1918
    ### appropriate RemoteTransport in the future.
 
1919
    return [(SmartTCPTransport, SmartTCPServer_for_testing)]