/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_branch.py

  • Committer: John Arbash Meinel
  • Date: 2009-12-22 16:28:47 UTC
  • mto: This revision was merged to the branch mainline in revision 4922.
  • Revision ID: john@arbash-meinel.com-20091222162847-tvnsc69to4l4uf5r
Implement a permute_for_extension helper.

Use it for all of the 'simple' extension permutations.
It permutes all tests in the current module by setting TestCase.module,
which works well for most of our extension tests. Some tests had more
advanced handling of permutations (extra permutations, custom vars, etc.).

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2005, 2006, 2007 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""Tests for the Branch facility that are not interface tests.
 
18
 
 
19
For interface tests see tests/per_branch/*.py.
 
20
 
 
21
For concrete class tests see this file, and for meta-branch tests
 
22
also see this file.
 
23
"""
 
24
 
 
25
from StringIO import StringIO
 
26
 
 
27
from bzrlib import (
 
28
    branch as _mod_branch,
 
29
    bzrdir,
 
30
    config,
 
31
    errors,
 
32
    trace,
 
33
    urlutils,
 
34
    )
 
35
from bzrlib.branch import (
 
36
    Branch,
 
37
    BranchHooks,
 
38
    BranchFormat,
 
39
    BranchReferenceFormat,
 
40
    BzrBranch5,
 
41
    BzrBranchFormat5,
 
42
    BzrBranchFormat6,
 
43
    BzrBranchFormat7,
 
44
    PullResult,
 
45
    _run_with_write_locked_target,
 
46
    )
 
47
from bzrlib.bzrdir import (BzrDirMetaFormat1, BzrDirMeta1,
 
48
                           BzrDir, BzrDirFormat)
 
49
from bzrlib.errors import (NotBranchError,
 
50
                           UnknownFormatError,
 
51
                           UnknownHook,
 
52
                           UnsupportedFormatError,
 
53
                           )
 
54
 
 
55
from bzrlib.tests import TestCase, TestCaseWithTransport
 
56
from bzrlib.transport import get_transport
 
57
 
 
58
 
 
59
class TestDefaultFormat(TestCase):
    """Checks on the process-wide default branch format."""

    def test_default_format(self):
        # This expectation must be updated whenever the default branch
        # format changes.
        current_default = BranchFormat.get_default_format()
        self.assertIsInstance(current_default, BzrBranchFormat7)

    def test_default_format_is_same_as_bzrdir_default(self):
        # XXX: it might be nice if there was only one place the default was
        # set, but at the moment that's not true -- mbp 20070814 --
        # https://bugs.launchpad.net/bzr/+bug/132376
        branch_default = BranchFormat.get_default_format()
        bzrdir_default = BzrDirFormat.get_default_format()
        self.assertEqual(branch_default, bzrdir_default.get_branch_format())

    def test_get_set_default_format(self):
        # Swap in a sample format, then check the original comes back.
        saved_format = BranchFormat.get_default_format()
        BranchFormat.set_default_format(SampleBranchFormat())
        try:
            # The meta dir format (which is not the default bzrdir format
            # at this point) uses the default branch format.
            control_dir = BzrDirMetaFormat1().initialize('memory:///')
            created = control_dir.create_branch()
            self.assertEqual(created, 'A branch')
        finally:
            BranchFormat.set_default_format(saved_format)
        self.assertEqual(saved_format, BranchFormat.get_default_format())
 
86
 
 
87
 
 
88
class TestBranchFormat5(TestCaseWithTransport):
    """Tests specific to branch format 5"""

    def test_branch_format_5_uses_lockdir(self):
        """A format 5 branch keeps a lock directory under .bzr/branch."""
        url = self.get_url()
        # Named control_dir so the module-level ``bzrdir`` import is not
        # shadowed inside this test.
        control_dir = BzrDirMetaFormat1().initialize(url)
        control_dir.create_repository()
        branch = control_dir.create_branch()
        t = self.get_transport()
        self.log("branch instance is %r" % branch)
        # assertIsInstance reports the actual type on failure, unlike the
        # deprecated assert_(isinstance(...)) spelling.
        self.assertIsInstance(branch, BzrBranch5)
        self.assertIsDirectory('.', t)
        self.assertIsDirectory('.bzr/branch', t)
        self.assertIsDirectory('.bzr/branch/lock', t)
        branch.lock_write()
        try:
            # While write-locked, the lock directory holds a 'held' entry.
            self.assertIsDirectory('.bzr/branch/lock/held', t)
        finally:
            branch.unlock()

    def test_set_push_location(self):
        """set_push_location writes a norecurse section to locations.conf."""
        from bzrlib.config import (locations_config_filename,
                                   ensure_config_dir_exists)
        ensure_config_dir_exists()
        fn = locations_config_filename()
        # write correct newlines to locations.conf
        # by default ConfigObj uses native line-endings for new files
        # but uses already existing line-endings if file is not empty
        f = open(fn, 'wb')
        try:
            f.write('# comment\n')
        finally:
            f.close()

        branch = self.make_branch('.', format='knit')
        branch.set_push_location('foo')
        # The section name is the branch's local path without the trailing
        # slash of its base URL.
        local_path = urlutils.local_path_from_url(branch.base[:-1])
        self.assertFileEqual("# comment\n"
                             "[%s]\n"
                             "push_location = foo\n"
                             "push_location:policy = norecurse\n" % local_path,
                             fn)

    # TODO RBC 20051029 test getting a push location from a branch in a
    # recursive section - that is, it appends the branch name.
 
133
 
 
134
 
 
135
class SampleBranchFormat(BranchFormat):
    """A sample branch format used by these tests.

    This format can be initialized but reports itself as unsupported, to aid
    in testing the open and open_downlevel routines.
    """

    def get_format_string(self):
        """See BranchFormat.get_format_string()."""
        return "Sample branch format."

    def initialize(self, a_bzrdir):
        """Write this format's marker into a_bzrdir's branch transport.

        Only the 'format' file is written; no real branch object is built.
        The sentinel string 'A branch' is returned so tests can recognise
        that this format's initialize was the one called.
        """
        t = a_bzrdir.get_branch_transport(self)
        t.put_bytes('format', self.get_format_string())
        return 'A branch'

    def is_supported(self):
        """Report this format as unsupported, so Branch.open refuses it."""
        return False

    def open(self, transport, _found=False, ignore_fallbacks=False):
        """Return a sentinel string instead of opening a real branch."""
        return "opened branch."
 
157
 
 
158
 
 
159
class TestBzrBranchFormat(TestCaseWithTransport):
    """Tests for the BzrBranchFormat facility."""

    def test_find_format(self):
        # Is the right format object found for a branch? Create a branch in
        # a known format and check that find_format detects that format.
        self.build_tree(["foo/", "bar/"])
        def check_format(expected_format, url):
            control = expected_format._matchingbzrdir.initialize(url)
            control.create_repository()
            expected_format.initialize(control)
            found_format = BranchFormat.find_format(control)
            self.assertIsInstance(found_format, expected_format.__class__)
        check_format(BzrBranchFormat5(), "bar")

    def test_find_format_not_branch(self):
        """A control dir without a branch raises NotBranchError."""
        control = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
        self.assertRaises(NotBranchError,
                          BranchFormat.find_format,
                          control)

    def test_find_format_unknown_format(self):
        """An unregistered format string raises UnknownFormatError."""
        control = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
        SampleBranchFormat().initialize(control)
        self.assertRaises(UnknownFormatError,
                          BranchFormat.find_format,
                          control)

    def test_register_unregister_format(self):
        sample_format = SampleBranchFormat()
        # make a control dir
        control = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
        # make a branch
        sample_format.initialize(control)
        # register a format for it.
        BranchFormat.register_format(sample_format)
        # The registry is global: unregister in a finally so a failing
        # assertion cannot leak the sample format into later tests.
        try:
            # which branch.Open will refuse (not supported)
            self.assertRaises(UnsupportedFormatError, Branch.open,
                              self.get_url())
            self.make_branch_and_tree('foo')
            # but open_downlevel will work
            self.assertEqual(sample_format.open(control),
                bzrdir.BzrDir.open(self.get_url()).open_branch(
                    unsupported=True))
        finally:
            # unregister the format
            BranchFormat.unregister_format(sample_format)
        self.make_branch_and_tree('bar')
 
204
 
 
205
 
 
206
class TestBranch67(object):
    """Common tests for both branch 6 and 7 which are mostly the same.

    This is a mixin: concrete test classes combine it with
    TestCaseWithTransport and override get_format_name,
    get_format_name_subtree and get_class.
    """

    def get_format_name(self):
        """Return the format name used to create the branch under test."""
        raise NotImplementedError(self.get_format_name)

    def get_format_name_subtree(self):
        """Return the format name used for trees that hold tree references."""
        # Report this method (not get_format_name) as the unimplemented one.
        raise NotImplementedError(self.get_format_name_subtree)

    def get_class(self):
        """Return the branch class the format is expected to produce."""
        raise NotImplementedError(self.get_class)

    def test_creation(self):
        # NOTE(review): branch 'a' is always created with BzrBranchFormat6,
        # even in the format 7 subclass, so this only passes there if a
        # format 6 branch is also an instance of get_class(); confirm that
        # this is intended.
        format = BzrDirMetaFormat1()
        format.set_branch_format(_mod_branch.BzrBranchFormat6())
        branch = self.make_branch('a', format=format)
        self.assertIsInstance(branch, self.get_class())
        branch = self.make_branch('b', format=self.get_format_name())
        self.assertIsInstance(branch, self.get_class())
        branch = _mod_branch.Branch.open('a')
        self.assertIsInstance(branch, self.get_class())

    def test_layout(self):
        # A new branch stores its tip in last-revision, with no
        # revision-history or references file.
        branch = self.make_branch('a', format=self.get_format_name())
        self.failUnlessExists('a/.bzr/branch/last-revision')
        self.failIfExists('a/.bzr/branch/revision-history')
        self.failIfExists('a/.bzr/branch/references')

    def test_config(self):
        """Ensure that all configuration data is stored in the branch"""
        branch = self.make_branch('a', format=self.get_format_name())
        branch.set_parent('http://bazaar-vcs.org')
        self.failIfExists('a/.bzr/branch/parent')
        self.assertEqual('http://bazaar-vcs.org', branch.get_parent())
        branch.set_push_location('sftp://bazaar-vcs.org')
        # Named branch_config so the module-level ``config`` import is not
        # shadowed inside this test.
        branch_config = branch.get_config()._get_branch_data_config()
        self.assertEqual('sftp://bazaar-vcs.org',
                         branch_config.get_user_option('push_location'))
        branch.set_bound_location('ftp://bazaar-vcs.org')
        self.failIfExists('a/.bzr/branch/bound')
        self.assertEqual('ftp://bazaar-vcs.org', branch.get_bound_location())

    def test_set_revision_history(self):
        # set_revision_history accepts lefthand histories ('foo', 'bar' and
        # a truncation to 'foo') but rejects ['bar'] alone.
        builder = self.make_branch_builder('.', format=self.get_format_name())
        builder.build_snapshot('foo', None,
            [('add', ('', None, 'directory', None))],
            message='foo')
        builder.build_snapshot('bar', None, [], message='bar')
        branch = builder.get_branch()
        branch.lock_write()
        self.addCleanup(branch.unlock)
        branch.set_revision_history(['foo', 'bar'])
        branch.set_revision_history(['foo'])
        self.assertRaises(errors.NotLefthandHistory,
                          branch.set_revision_history, ['bar'])

    def do_checkout_test(self, lightweight=False):
        """Check out a tree with nested tree references into 'target'.

        For a lightweight checkout the innermost branch is expected to be
        the one under 'source'; otherwise it is a new one under 'target'.
        """
        tree = self.make_branch_and_tree('source',
            format=self.get_format_name_subtree())
        subtree = self.make_branch_and_tree('source/subtree',
            format=self.get_format_name_subtree())
        subsubtree = self.make_branch_and_tree('source/subtree/subsubtree',
            format=self.get_format_name_subtree())
        self.build_tree(['source/subtree/file',
                         'source/subtree/subsubtree/file'])
        subsubtree.add('file')
        subtree.add('file')
        subtree.add_reference(subsubtree)
        tree.add_reference(subtree)
        tree.commit('a revision')
        subtree.commit('a subtree file')
        subsubtree.commit('a subsubtree file')
        tree.branch.create_checkout('target', lightweight=lightweight)
        self.failUnlessExists('target')
        self.failUnlessExists('target/subtree')
        self.failUnlessExists('target/subtree/file')
        self.failUnlessExists('target/subtree/subsubtree/file')
        subbranch = _mod_branch.Branch.open('target/subtree/subsubtree')
        if lightweight:
            self.assertEndsWith(subbranch.base, 'source/subtree/subsubtree/')
        else:
            self.assertEndsWith(subbranch.base, 'target/subtree/subsubtree/')

    def test_checkout_with_references(self):
        self.do_checkout_test()

    def test_light_checkout_with_references(self):
        self.do_checkout_test(lightweight=True)

    def test_set_push(self):
        # A push_location in locations.conf masks the one set on the branch,
        # and set_push_location is expected to warn about that.
        branch = self.make_branch('source', format=self.get_format_name())
        branch.get_config().set_user_option('push_location', 'old',
            store=config.STORE_LOCATION)
        warnings = []
        def warning(*args):
            warnings.append(args[0] % args[1:])
        # Temporarily capture trace.warning; always restore the original.
        _warning = trace.warning
        trace.warning = warning
        try:
            branch.set_push_location('new')
        finally:
            trace.warning = _warning
        self.assertEqual(warnings[0], 'Value "new" is masked by "old" from '
                         'locations.conf')
 
310
 
 
311
 
 
312
class TestBranch6(TestBranch67, TestCaseWithTransport):
    """Runs the shared format 6/7 tests against format 6 ("dirstate-tags").

    Format 6 branches are expected to refuse every stacking operation.
    """

    def get_class(self):
        return _mod_branch.BzrBranch6

    def get_format_name(self):
        return "dirstate-tags"

    def get_format_name_subtree(self):
        return "dirstate-with-subtree"

    def test_set_stacked_on_url_errors(self):
        unstackable = self.make_branch('a', format=self.get_format_name())
        self.assertRaises(errors.UnstackableBranchFormat,
                          unstackable.set_stacked_on_url, None)

    def test_default_stacked_location(self):
        unstackable = self.make_branch('a', format=self.get_format_name())
        self.assertRaises(errors.UnstackableBranchFormat,
                          unstackable.get_stacked_on_url)
 
331
 
 
332
 
 
333
class TestBranch7(TestBranch67, TestCaseWithTransport):
    """Runs the shared format 6/7 tests against format 7 ("1.9").

    Also covers stacking behaviour, which these tests exercise only for
    format 7.
    """

    def get_class(self):
        return _mod_branch.BzrBranch7

    def get_format_name(self):
        return "1.9"

    def get_format_name_subtree(self):
        return "development-subtree"

    def test_set_stacked_on_url_unstackable_repo(self):
        # A format 7 branch on top of a repository format that cannot stack.
        control = self.make_repository('a', format='dirstate-tags').bzrdir
        stacking_branch = _mod_branch.BzrBranchFormat7().initialize(control)
        fallback = self.make_branch('b')
        self.assertRaises(errors.UnstackableRepositoryFormat,
            stacking_branch.set_stacked_on_url, fallback.base)

    def test_clone_stacked_on_unstackable_repo(self):
        control = self.make_repository('a', format='dirstate-tags').bzrdir
        _mod_branch.BzrBranchFormat7().initialize(control)
        # Calling clone should not raise UnstackableRepositoryFormat.
        control.clone('cloned')

    def _test_default_stacked_location(self):
        # NOTE(review): the leading underscore keeps the test loader from
        # collecting this method, so it never runs; confirm whether that is
        # deliberate.
        unstacked = self.make_branch('a', format=self.get_format_name())
        self.assertRaises(errors.NotStacked, unstacked.get_stacked_on_url)

    def test_stack_and_unstack(self):
        stacked = self.make_branch('a', format=self.get_format_name())
        base_tree = self.make_branch_and_tree('b',
                                              format=self.get_format_name())
        stacked.set_stacked_on_url(base_tree.branch.base)
        self.assertEqual(base_tree.branch.base, stacked.get_stacked_on_url())
        # A revision committed to the stacked-on branch is visible through
        # the stacked branch's repository...
        new_revid = base_tree.commit('foo')
        self.assertTrue(stacked.repository.has_revision(new_revid))
        # ...until the branch is unstacked again.
        stacked.set_stacked_on_url(None)
        self.assertRaises(errors.NotStacked, stacked.get_stacked_on_url)
        self.assertFalse(stacked.repository.has_revision(new_revid))

    def test_open_opens_stacked_reference(self):
        stacked = self.make_branch('a', format=self.get_format_name())
        base_tree = self.make_branch_and_tree('b',
                                              format=self.get_format_name())
        stacked.set_stacked_on_url(base_tree.branch.base)
        # Re-open the branch: the stacking must survive a fresh open.
        reopened = stacked.bzrdir.open_branch()
        new_revid = base_tree.commit('foo')
        self.assertTrue(reopened.repository.has_revision(new_revid))
 
381
 
 
382
 
 
383
class BzrBranch8(TestCaseWithTransport):
    """Tests for reference-info storage and caching on format 8 branches."""

    def make_branch(self, location, format=None):
        """Make a branch; by default a '1.9' bzrdir with BzrBranchFormat8."""
        if format is None:
            format = bzrdir.format_registry.make_bzrdir('1.9')
            format.set_branch_format(_mod_branch.BzrBranchFormat8())
        return TestCaseWithTransport.make_branch(self, location, format=format)

    def create_branch_with_reference(self):
        """Make 'branch' with reference info 'file-id' -> ('path', 'location')."""
        branch = self.make_branch('branch')
        branch._set_all_reference_info({'file-id': ('path', 'location')})
        return branch

    @staticmethod
    def instrument_branch(branch, gets):
        """Append (args, kwargs) to gets for every branch._transport.get call."""
        old_get = branch._transport.get
        def get(*args, **kwargs):
            gets.append((args, kwargs))
            return old_get(*args, **kwargs)
        branch._transport.get = get

    def test_reference_info_caching_read_locked(self):
        # Under a read lock the reference info is read from disk only once.
        gets = []
        branch = self.create_branch_with_reference()
        branch.lock_read()
        self.addCleanup(branch.unlock)
        self.instrument_branch(branch, gets)
        branch.get_reference_info('file-id')
        branch.get_reference_info('file-id')
        self.assertEqual(1, len(gets))

    def test_reference_info_caching_read_unlocked(self):
        # Without a lock every lookup goes back to disk.
        gets = []
        branch = self.create_branch_with_reference()
        self.instrument_branch(branch, gets)
        branch.get_reference_info('file-id')
        branch.get_reference_info('file-id')
        self.assertEqual(2, len(gets))

    def test_reference_info_caching_write_locked(self):
        # Under a write lock, info just set is served without any disk read.
        gets = []
        branch = self.make_branch('branch')
        branch.lock_write()
        # Register the unlock straight after locking so the lock is released
        # even if instrumenting the branch fails.
        self.addCleanup(branch.unlock)
        self.instrument_branch(branch, gets)
        branch._set_all_reference_info({'file-id': ('path2', 'location2')})
        path, location = branch.get_reference_info('file-id')
        self.assertEqual(0, len(gets))
        self.assertEqual('path2', path)
        self.assertEqual('location2', location)

    def test_reference_info_caches_cleared(self):
        # After unlocking, changes made through another Branch object must
        # be visible, i.e. no stale cache survives the unlock.
        branch = self.make_branch('branch')
        branch.lock_write()
        try:
            branch.set_reference_info('file-id', 'path2', 'location2')
        finally:
            # Release the lock even if set_reference_info fails.
            branch.unlock()
        doppelganger = Branch.open('branch')
        doppelganger.set_reference_info('file-id', 'path3', 'location3')
        self.assertEqual(('path3', 'location3'),
                         branch.get_reference_info('file-id'))
 
443
 
 
444
class TestBranchReference(TestCaseWithTransport):
    """Tests for the branch reference facility."""

    def test_create_open_reference(self):
        meta_format = bzrdir.BzrDirMetaFormat1()
        root = get_transport(self.get_url('.'))
        # A real branch lives in 'repo'...
        root.mkdir('repo')
        repo_dir = meta_format.initialize(self.get_url('repo'))
        repo_dir.create_repository()
        target_branch = repo_dir.create_branch()
        # ...and 'branch' only holds a reference to it.
        root.mkdir('branch')
        reference_dir = meta_format.initialize(self.get_url('branch'))
        made_branch = BranchReferenceFormat().initialize(reference_dir,
                                                         target_branch)
        self.assertEqual(made_branch.base, target_branch.base)
        opened_branch = reference_dir.open_branch()
        self.assertEqual(opened_branch.base, target_branch.base)

    def test_get_reference(self):
        """For a BranchReference, get_reference should return the location."""
        target = self.make_branch('target')
        checkout = target.create_checkout('checkout', lightweight=True)
        expected_url = target.bzrdir.root_transport.abspath('') + '/'
        # This file read assumes create_checkout(lightweight=True) produces a
        # branch reference; it will fail if that API starts returning a
        # different kind of checkout.
        self.assertFileEqual(expected_url, 'checkout/.bzr/branch/location')
        reference_format = _mod_branch.BranchReferenceFormat()
        self.assertEqual(expected_url,
                         reference_format.get_reference(checkout.bzrdir))
 
471
 
 
472
 
 
473
class TestHooks(TestCase):
    """Tests for the BranchHooks registry."""

    def test_constructor(self):
        """Check that creating a BranchHooks instance has the right defaults."""
        hooks = BranchHooks()
        expected_hook_names = [
            "set_rh",
            "post_push",
            "post_commit",
            "pre_commit",
            "post_pull",
            "post_uncommit",
            "post_change_branch_tip",
            ]
        for hook_name in expected_hook_names:
            self.assertTrue(hook_name in hooks,
                            "%s not in %s" % (hook_name, hooks))

    def test_installed_hooks_are_BranchHooks(self):
        """The installed hooks object should be a BranchHooks."""
        # the installed hooks are saved in self._preserved_hooks.
        installed = self._preserved_hooks[_mod_branch.Branch][1]
        self.assertIsInstance(installed, BranchHooks)
 
492
 
 
493
 
 
494
class TestPullResult(TestCase):
    """Tests for PullResult formatting and reporting."""

    def test_pull_result_to_int(self):
        # For API stability a pull result can still be formatted as an int.
        # New code should not rely on this, since the number says little
        # about what actually happened.
        pull_result = PullResult()
        pull_result.old_revno = 10
        pull_result.new_revno = 20
        formatted = "%d revisions pulled" % pull_result
        self.assertEqual(formatted, "10 revisions pulled")

    def test_report_changed(self):
        pull_result = PullResult()
        pull_result.old_revid = "old-revid"
        pull_result.old_revno = 10
        pull_result.new_revid = "new-revid"
        pull_result.new_revno = 20
        out = StringIO()
        pull_result.report(out)
        self.assertEqual("Now on revision 20.\n", out.getvalue())

    def test_report_unchanged(self):
        pull_result = PullResult()
        pull_result.old_revid = "same-revid"
        pull_result.new_revid = "same-revid"
        out = StringIO()
        pull_result.report(out)
        self.assertEqual("No revisions to pull.\n", out.getvalue())
 
524
 
 
525
 
 
526
class _StubLockable(object):
 
527
    """Helper for TestRunWithWriteLockedTarget."""
 
528
 
 
529
    def __init__(self, calls, unlock_exc=None):
 
530
        self.calls = calls
 
531
        self.unlock_exc = unlock_exc
 
532
 
 
533
    def lock_write(self):
 
534
        self.calls.append('lock_write')
 
535
 
 
536
    def unlock(self):
 
537
        self.calls.append('unlock')
 
538
        if self.unlock_exc is not None:
 
539
            raise self.unlock_exc
 
540
 
 
541
 
 
542
class _ErrorFromCallable(Exception):
 
543
    """Helper for TestRunWithWriteLockedTarget."""
 
544
 
 
545
 
 
546
class _ErrorFromUnlock(Exception):
 
547
    """Helper for TestRunWithWriteLockedTarget."""
 
548
 
 
549
 
 
550
class TestRunWithWriteLockedTarget(TestCase):
    """Tests for _run_with_write_locked_target."""

    def setUp(self):
        TestCase.setUp(self)
        self._calls = []

    def _expected_calls(self):
        """The event sequence every test in this class expects to see."""
        return ['lock_write', 'func called', 'unlock']

    def func_that_returns_ok(self):
        self._calls.append('func called')
        return 'ok'

    def func_that_raises(self):
        self._calls.append('func called')
        raise _ErrorFromCallable()

    def test_success_unlocks(self):
        lockable = _StubLockable(self._calls)
        outcome = _run_with_write_locked_target(lockable,
                                                self.func_that_returns_ok)
        self.assertEqual('ok', outcome)
        self.assertEqual(self._expected_calls(), self._calls)

    def test_exception_unlocks_and_propagates(self):
        lockable = _StubLockable(self._calls)
        self.assertRaises(_ErrorFromCallable, _run_with_write_locked_target,
                          lockable, self.func_that_raises)
        self.assertEqual(self._expected_calls(), self._calls)

    def test_callable_succeeds_but_error_during_unlock(self):
        # With a successful callable, an unlock failure is what surfaces.
        lockable = _StubLockable(self._calls, unlock_exc=_ErrorFromUnlock())
        self.assertRaises(_ErrorFromUnlock, _run_with_write_locked_target,
                          lockable, self.func_that_returns_ok)
        self.assertEqual(self._expected_calls(), self._calls)

    def test_error_during_unlock_does_not_mask_original_error(self):
        # When both fail, the callable's error wins over the unlock error.
        lockable = _StubLockable(self._calls, unlock_exc=_ErrorFromUnlock())
        self.assertRaises(_ErrorFromCallable, _run_with_write_locked_target,
                          lockable, self.func_that_raises)
        self.assertEqual(self._expected_calls(), self._calls)
 
589
 
 
590