/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_branch.py

  • Committer: John Arbash Meinel
  • Date: 2009-06-18 18:18:36 UTC
  • mto: This revision was merged to the branch mainline in revision 4461.
  • Revision ID: john@arbash-meinel.com-20090618181836-biodfkat9a8eyzjz
The new add_inventory_by_delta is returning a CHKInventory when mapping from NULL,
which is completely valid, but it 'broke' one of the tests.
So to fix it, I changed the test to use CHKInventories on both sides and added an __eq__
member. The nice thing is that CHKInventory.__eq__ is fairly cheap, since it only
has to check the root keys.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2005, 2006, 2007 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""Tests for the Branch facility that are not interface  tests.
 
18
 
 
19
For interface tests see tests/branch_implementations/*.py.
 
20
 
 
21
For concrete class tests see this file, and for meta-branch tests
 
22
also see this file.
 
23
"""
 
24
 
 
25
from StringIO import StringIO
 
26
 
 
27
from bzrlib import (
 
28
    branch as _mod_branch,
 
29
    bzrdir,
 
30
    config,
 
31
    errors,
 
32
    trace,
 
33
    urlutils,
 
34
    )
 
35
from bzrlib.branch import (
 
36
    Branch,
 
37
    BranchHooks,
 
38
    BranchFormat,
 
39
    BranchReferenceFormat,
 
40
    BzrBranch5,
 
41
    BzrBranchFormat5,
 
42
    BzrBranchFormat6,
 
43
    PullResult,
 
44
    _run_with_write_locked_target,
 
45
    )
 
46
from bzrlib.bzrdir import (BzrDirMetaFormat1, BzrDirMeta1,
 
47
                           BzrDir, BzrDirFormat)
 
48
from bzrlib.errors import (NotBranchError,
 
49
                           UnknownFormatError,
 
50
                           UnknownHook,
 
51
                           UnsupportedFormatError,
 
52
                           )
 
53
 
 
54
from bzrlib.tests import TestCase, TestCaseWithTransport
 
55
from bzrlib.transport import get_transport
 
56
 
 
57
 
 
58
class TestDefaultFormat(TestCase):
    """Checks on the process-wide default branch format."""

    def test_default_format(self):
        # Keep this in step with the default branch format whenever it
        # changes.
        current = BranchFormat.get_default_format()
        self.assertIsInstance(current, BzrBranchFormat6)

    def test_default_format_is_same_as_bzrdir_default(self):
        # XXX: it might be nice if there was only one place the default was
        # set, but at the moment that's not true -- mbp 20070814 --
        # https://bugs.launchpad.net/bzr/+bug/132376
        branch_default = BranchFormat.get_default_format()
        bzrdir_default = BzrDirFormat.get_default_format()
        self.assertEqual(branch_default, bzrdir_default.get_branch_format())

    def test_get_set_default_format(self):
        # Swap in a sample format, exercise it, then put the old one back.
        saved_format = BranchFormat.get_default_format()
        BranchFormat.set_default_format(SampleBranchFormat())
        try:
            # The meta-dir format consults the default branch format; it is
            # not the default bzrdir format at this point.
            control = BzrDirMetaFormat1().initialize('memory:///')
            created = control.create_branch()
            self.assertEqual(created, 'A branch')
        finally:
            BranchFormat.set_default_format(saved_format)
        self.assertEqual(saved_format, BranchFormat.get_default_format())
 
85
 
 
86
 
 
87
class TestBranchFormat5(TestCaseWithTransport):
    """Tests specific to branch format 5"""

    def test_branch_format_5_uses_lockdir(self):
        """A format-5 branch has a lock directory that is held while locked."""
        url = self.get_url()
        # Named 'control' rather than 'bzrdir' so the imported bzrdir module
        # is not shadowed inside this test.
        control = BzrDirMetaFormat1().initialize(url)
        control.create_repository()
        branch = control.create_branch()
        t = self.get_transport()
        self.log("branch instance is %r" % branch)
        self.assertIsInstance(branch, BzrBranch5)
        self.assertIsDirectory('.', t)
        self.assertIsDirectory('.bzr/branch', t)
        self.assertIsDirectory('.bzr/branch/lock', t)
        branch.lock_write()
        try:
            self.assertIsDirectory('.bzr/branch/lock/held', t)
        finally:
            branch.unlock()

    def test_set_push_location(self):
        """set_push_location writes a section for the branch to locations.conf."""
        from bzrlib.config import (locations_config_filename,
                                   ensure_config_dir_exists)
        ensure_config_dir_exists()
        fn = locations_config_filename()
        # write correct newlines to locations.conf
        # by default ConfigObj uses native line-endings for new files
        # but uses already existing line-endings if file is not empty
        f = open(fn, 'wb')
        try:
            f.write('# comment\n')
        finally:
            f.close()

        branch = self.make_branch('.', format='knit')
        branch.set_push_location('foo')
        # branch.base ends with '/'; the config section name does not.
        local_path = urlutils.local_path_from_url(branch.base[:-1])
        self.assertFileEqual("# comment\n"
                             "[%s]\n"
                             "push_location = foo\n"
                             "push_location:policy = norecurse\n" % local_path,
                             fn)

    # TODO RBC 20051029 test getting a push location from a branch in a
    # recursive section - that is, it appends the branch name.
 
132
 
 
133
 
 
134
class SampleBranchFormat(BranchFormat):
    """A sample format

    this format is initializable, unsupported to aid in testing the
    open and open_downlevel routines.
    """

    def get_format_string(self):
        """See BzrBranchFormat.get_format_string()."""
        return "Sample branch format."

    def initialize(self, a_bzrdir):
        """Write this format's marker into a_bzrdir's branch transport.

        Only the 'format' file is written; no real branch object is built.

        :param a_bzrdir: the control directory to put the marker in.
        :return: the sentinel string 'A branch', which tests compare against.
        """
        t = a_bzrdir.get_branch_transport(self)
        t.put_bytes('format', self.get_format_string())
        return 'A branch'

    def is_supported(self):
        """Report the format as unsupported so ordinary opens refuse it."""
        return False

    def open(self, transport, _found=False, ignore_fallbacks=False):
        """Return a sentinel string instead of a real branch object."""
        return "opened branch."
 
156
 
 
157
 
 
158
class TestBzrBranchFormat(TestCaseWithTransport):
    """Tests for the BzrBranchFormat facility."""

    def test_find_format(self):
        # is the right format object found for a branch?
        # create a branch with a few known format objects.
        # NOTE(review): the original comment here was truncated ("this is not
        # quite the same as"); its intended ending is unknown.
        self.build_tree(["foo/", "bar/"])
        def check_format(format, url):
            control = format._matchingbzrdir.initialize(url)
            control.create_repository()
            format.initialize(control)
            found_format = BranchFormat.find_format(control)
            self.assertIsInstance(found_format, format.__class__)
        check_format(BzrBranchFormat5(), "bar")

    def test_find_format_not_branch(self):
        """find_format on a control dir with no branch raises NotBranchError."""
        control = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
        self.assertRaises(NotBranchError,
                          BranchFormat.find_format,
                          control)

    def test_find_format_unknown_format(self):
        """An unregistered format string raises UnknownFormatError."""
        control = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
        SampleBranchFormat().initialize(control)
        self.assertRaises(UnknownFormatError,
                          BranchFormat.find_format,
                          control)

    def test_register_unregister_format(self):
        format = SampleBranchFormat()
        # make a control dir
        control = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
        # make a branch
        format.initialize(control)
        # register a format for it.
        BranchFormat.register_format(format)
        try:
            # which branch.Open will refuse (not supported)
            self.assertRaises(UnsupportedFormatError, Branch.open,
                              self.get_url())
            self.make_branch_and_tree('foo')
            # but open_downlevel will work
            self.assertEqual(format.open(control),
                bzrdir.BzrDir.open(self.get_url()).open_branch(
                    unsupported=True))
        finally:
            # Always unregister, even if an assertion above failed, so the
            # global format registry is not left polluted for later tests.
            BranchFormat.unregister_format(format)
        self.make_branch_and_tree('bar')
 
203
 
 
204
 
 
205
class TestBranch67(object):
 
206
    """Common tests for both branch 6 and 7 which are mostly the same."""
 
207
 
 
208
    def get_format_name(self):
 
209
        raise NotImplementedError(self.get_format_name)
 
210
 
 
211
    def get_format_name_subtree(self):
 
212
        raise NotImplementedError(self.get_format_name)
 
213
 
 
214
    def get_class(self):
 
215
        raise NotImplementedError(self.get_class)
 
216
 
 
217
    def test_creation(self):
 
218
        format = BzrDirMetaFormat1()
 
219
        format.set_branch_format(_mod_branch.BzrBranchFormat6())
 
220
        branch = self.make_branch('a', format=format)
 
221
        self.assertIsInstance(branch, self.get_class())
 
222
        branch = self.make_branch('b', format=self.get_format_name())
 
223
        self.assertIsInstance(branch, self.get_class())
 
224
        branch = _mod_branch.Branch.open('a')
 
225
        self.assertIsInstance(branch, self.get_class())
 
226
 
 
227
    def test_layout(self):
 
228
        branch = self.make_branch('a', format=self.get_format_name())
 
229
        self.failUnlessExists('a/.bzr/branch/last-revision')
 
230
        self.failIfExists('a/.bzr/branch/revision-history')
 
231
        self.failIfExists('a/.bzr/branch/references')
 
232
 
 
233
    def test_config(self):
 
234
        """Ensure that all configuration data is stored in the branch"""
 
235
        branch = self.make_branch('a', format=self.get_format_name())
 
236
        branch.set_parent('http://bazaar-vcs.org')
 
237
        self.failIfExists('a/.bzr/branch/parent')
 
238
        self.assertEqual('http://bazaar-vcs.org', branch.get_parent())
 
239
        branch.set_push_location('sftp://bazaar-vcs.org')
 
240
        config = branch.get_config()._get_branch_data_config()
 
241
        self.assertEqual('sftp://bazaar-vcs.org',
 
242
                         config.get_user_option('push_location'))
 
243
        branch.set_bound_location('ftp://bazaar-vcs.org')
 
244
        self.failIfExists('a/.bzr/branch/bound')
 
245
        self.assertEqual('ftp://bazaar-vcs.org', branch.get_bound_location())
 
246
 
 
247
    def test_set_revision_history(self):
 
248
        builder = self.make_branch_builder('.', format=self.get_format_name())
 
249
        builder.build_snapshot('foo', None,
 
250
            [('add', ('', None, 'directory', None))],
 
251
            message='foo')
 
252
        builder.build_snapshot('bar', None, [], message='bar')
 
253
        branch = builder.get_branch()
 
254
        branch.lock_write()
 
255
        self.addCleanup(branch.unlock)
 
256
        branch.set_revision_history(['foo', 'bar'])
 
257
        branch.set_revision_history(['foo'])
 
258
        self.assertRaises(errors.NotLefthandHistory,
 
259
                          branch.set_revision_history, ['bar'])
 
260
 
 
261
    def do_checkout_test(self, lightweight=False):
 
262
        tree = self.make_branch_and_tree('source',
 
263
            format=self.get_format_name_subtree())
 
264
        subtree = self.make_branch_and_tree('source/subtree',
 
265
            format=self.get_format_name_subtree())
 
266
        subsubtree = self.make_branch_and_tree('source/subtree/subsubtree',
 
267
            format=self.get_format_name_subtree())
 
268
        self.build_tree(['source/subtree/file',
 
269
                         'source/subtree/subsubtree/file'])
 
270
        subsubtree.add('file')
 
271
        subtree.add('file')
 
272
        subtree.add_reference(subsubtree)
 
273
        tree.add_reference(subtree)
 
274
        tree.commit('a revision')
 
275
        subtree.commit('a subtree file')
 
276
        subsubtree.commit('a subsubtree file')
 
277
        tree.branch.create_checkout('target', lightweight=lightweight)
 
278
        self.failUnlessExists('target')
 
279
        self.failUnlessExists('target/subtree')
 
280
        self.failUnlessExists('target/subtree/file')
 
281
        self.failUnlessExists('target/subtree/subsubtree/file')
 
282
        subbranch = _mod_branch.Branch.open('target/subtree/subsubtree')
 
283
        if lightweight:
 
284
            self.assertEndsWith(subbranch.base, 'source/subtree/subsubtree/')
 
285
        else:
 
286
            self.assertEndsWith(subbranch.base, 'target/subtree/subsubtree/')
 
287
 
 
288
    def test_checkout_with_references(self):
 
289
        self.do_checkout_test()
 
290
 
 
291
    def test_light_checkout_with_references(self):
 
292
        self.do_checkout_test(lightweight=True)
 
293
 
 
294
    def test_set_push(self):
 
295
        branch = self.make_branch('source', format=self.get_format_name())
 
296
        branch.get_config().set_user_option('push_location', 'old',
 
297
            store=config.STORE_LOCATION)
 
298
        warnings = []
 
299
        def warning(*args):
 
300
            warnings.append(args[0] % args[1:])
 
301
        _warning = trace.warning
 
302
        trace.warning = warning
 
303
        try:
 
304
            branch.set_push_location('new')
 
305
        finally:
 
306
            trace.warning = _warning
 
307
        self.assertEqual(warnings[0], 'Value "new" is masked by "old" from '
 
308
                         'locations.conf')
 
309
 
 
310
 
 
311
class TestBranch6(TestBranch67, TestCaseWithTransport):
    """Run the shared branch 6/7 tests against BzrBranch6."""

    def get_class(self):
        return _mod_branch.BzrBranch6

    def get_format_name(self):
        return "dirstate-tags"

    def get_format_name_subtree(self):
        return "dirstate-with-subtree"

    def test_set_stacked_on_url_errors(self):
        # This format cannot stack at all, so even clearing it is refused.
        b = self.make_branch('a', format=self.get_format_name())
        self.assertRaises(errors.UnstackableBranchFormat,
                          b.set_stacked_on_url, None)

    def test_default_stacked_location(self):
        b = self.make_branch('a', format=self.get_format_name())
        self.assertRaises(errors.UnstackableBranchFormat,
                          b.get_stacked_on_url)
 
330
 
 
331
 
 
332
class TestBranch7(TestBranch67, TestCaseWithTransport):
    """Run the shared branch 6/7 tests, plus stacking tests, on BzrBranch7."""

    def get_class(self):
        return _mod_branch.BzrBranch7

    def get_format_name(self):
        return "1.9"

    def get_format_name_subtree(self):
        return "development-subtree"

    def test_set_stacked_on_url_unstackable_repo(self):
        """Stacking is refused when the repository format cannot stack."""
        repo = self.make_repository('a', format='dirstate-tags')
        control = repo.bzrdir
        branch = _mod_branch.BzrBranchFormat7().initialize(control)
        target = self.make_branch('b')
        self.assertRaises(errors.UnstackableRepositoryFormat,
            branch.set_stacked_on_url, target.base)

    def test_clone_stacked_on_unstackable_repo(self):
        repo = self.make_repository('a', format='dirstate-tags')
        control = repo.bzrdir
        _mod_branch.BzrBranchFormat7().initialize(control)
        # Calling clone should not raise UnstackableRepositoryFormat; getting
        # past this call is the whole assertion.
        control.clone('cloned')

    # NOTE(review): the leading underscore keeps this from being collected as
    # a test; the reason it is disabled is not visible here.
    def _test_default_stacked_location(self):
        branch = self.make_branch('a', format=self.get_format_name())
        self.assertRaises(errors.NotStacked, branch.get_stacked_on_url)

    def test_stack_and_unstack(self):
        """Unstacking removes access to revisions held only by the target."""
        branch = self.make_branch('a', format=self.get_format_name())
        target = self.make_branch_and_tree('b', format=self.get_format_name())
        branch.set_stacked_on_url(target.branch.base)
        self.assertEqual(target.branch.base, branch.get_stacked_on_url())
        revid = target.commit('foo')
        self.assertTrue(branch.repository.has_revision(revid))
        branch.set_stacked_on_url(None)
        self.assertRaises(errors.NotStacked, branch.get_stacked_on_url)
        self.assertFalse(branch.repository.has_revision(revid))

    def test_open_opens_stacked_reference(self):
        """A reopened stacked branch sees revisions committed to its target."""
        branch = self.make_branch('a', format=self.get_format_name())
        target = self.make_branch_and_tree('b', format=self.get_format_name())
        branch.set_stacked_on_url(target.branch.base)
        branch = branch.bzrdir.open_branch()
        revid = target.commit('foo')
        self.assertTrue(branch.repository.has_revision(revid))
 
380
 
 
381
 
 
382
class BzrBranch8(TestCaseWithTransport):
    """Reference-info caching tests for BzrBranchFormat8 branches."""

    def make_branch(self, location, format=None):
        # Without an explicit format, use a 1.9 bzrdir whose branch format
        # has been switched to BzrBranchFormat8.
        if format is None:
            format = bzrdir.format_registry.make_bzrdir('1.9')
            format.set_branch_format(_mod_branch.BzrBranchFormat8())
        return TestCaseWithTransport.make_branch(self, location, format=format)

    def create_branch_with_reference(self):
        new_branch = self.make_branch('branch')
        new_branch._set_all_reference_info({'file-id': ('path', 'location')})
        return new_branch

    @staticmethod
    def instrument_branch(branch, gets):
        # Wrap the transport's get() so each call is logged into `gets`
        # before being forwarded to the real implementation.
        real_get = branch._transport.get
        def recording_get(*args, **kwargs):
            gets.append((args, kwargs))
            return real_get(*args, **kwargs)
        branch._transport.get = recording_get

    def test_reference_info_caching_read_locked(self):
        recorded = []
        b = self.create_branch_with_reference()
        b.lock_read()
        self.addCleanup(b.unlock)
        self.instrument_branch(b, recorded)
        b.get_reference_info('file-id')
        b.get_reference_info('file-id')
        # Read-locked: the second lookup is served from cache.
        self.assertEqual(1, len(recorded))

    def test_reference_info_caching_read_unlocked(self):
        recorded = []
        b = self.create_branch_with_reference()
        self.instrument_branch(b, recorded)
        b.get_reference_info('file-id')
        b.get_reference_info('file-id')
        # Unlocked: every lookup goes to the transport.
        self.assertEqual(2, len(recorded))

    def test_reference_info_caching_write_locked(self):
        recorded = []
        b = self.make_branch('branch')
        b.lock_write()
        self.instrument_branch(b, recorded)
        self.addCleanup(b.unlock)
        b._set_all_reference_info({'file-id': ('path2', 'location2')})
        ref_path, ref_location = b.get_reference_info('file-id')
        self.assertEqual(0, len(recorded))
        self.assertEqual('path2', ref_path)
        self.assertEqual('location2', ref_location)

    def test_reference_info_caches_cleared(self):
        b = self.make_branch('branch')
        b.lock_write()
        b.set_reference_info('file-id', 'path2', 'location2')
        b.unlock()
        other = Branch.open('branch')
        other.set_reference_info('file-id', 'path3', 'location3')
        self.assertEqual(('path3', 'location3'),
                         b.get_reference_info('file-id'))
 
442
 
 
443
class TestBranchReference(TestCaseWithTransport):
    """Tests for the branch reference facility."""

    def test_create_open_reference(self):
        """A branch reference opens to the branch it points at."""
        bzrdirformat = bzrdir.BzrDirMetaFormat1()
        t = get_transport(self.get_url('.'))
        t.mkdir('repo')
        # Named repo_dir rather than 'dir' to avoid shadowing the builtin.
        repo_dir = bzrdirformat.initialize(self.get_url('repo'))
        repo_dir.create_repository()
        target_branch = repo_dir.create_branch()
        t.mkdir('branch')
        branch_dir = bzrdirformat.initialize(self.get_url('branch'))
        made_branch = BranchReferenceFormat().initialize(branch_dir,
                                                         target_branch)
        self.assertEqual(made_branch.base, target_branch.base)
        opened_branch = branch_dir.open_branch()
        self.assertEqual(opened_branch.base, target_branch.base)

    def test_get_reference(self):
        """For a BranchReference, get_reference should return the location."""
        branch = self.make_branch('target')
        checkout = branch.create_checkout('checkout', lightweight=True)
        reference_url = branch.bzrdir.root_transport.abspath('') + '/'
        # if the api for create_checkout changes to return different checkout types
        # then this file read will fail.
        self.assertFileEqual(reference_url, 'checkout/.bzr/branch/location')
        self.assertEqual(reference_url,
            _mod_branch.BranchReferenceFormat().get_reference(checkout.bzrdir))
 
470
 
 
471
 
 
472
class TestHooks(TestCase):
    """Checks on the BranchHooks registry."""

    def test_constructor(self):
        """Check that creating a BranchHooks instance has the right defaults."""
        hooks = BranchHooks()
        expected_names = (
            "set_rh",
            "post_push",
            "post_commit",
            "pre_commit",
            "post_pull",
            "post_uncommit",
            "post_change_branch_tip",
            )
        for hook_name in expected_names:
            self.assertTrue(hook_name in hooks,
                            "%s not in %s" % (hook_name, hooks))

    def test_installed_hooks_are_BranchHooks(self):
        """The installed hooks object should be a BranchHooks."""
        # The test framework stashes the installed hooks in
        # self._preserved_hooks, keyed by owning class.
        installed = self._preserved_hooks[_mod_branch.Branch][1]
        self.assertIsInstance(installed, BranchHooks)
 
491
 
 
492
 
 
493
class TestPullResult(TestCase):
    """Checks on the PullResult value object."""

    def test_pull_result_to_int(self):
        # Older callers treat a pull result as an int.  New code should not
        # rely on this (it says little about what happened), but it stays
        # supported for API stability.
        result = PullResult()
        result.old_revno = 10
        result.new_revno = 20
        message = "%d revisions pulled" % result
        self.assertEqual(message, "10 revisions pulled")
 
505
 
 
506
 
 
507
 
 
508
class _StubLockable(object):
    """Helper for TestRunWithWriteLockedTarget.

    Appends 'lock_write' / 'unlock' to a shared call log, and optionally
    raises a preset exception from unlock().
    """

    def __init__(self, calls, unlock_exc=None):
        self.calls = calls
        self.unlock_exc = unlock_exc

    def lock_write(self):
        self.calls.append('lock_write')

    def unlock(self):
        self.calls.append('unlock')
        pending = self.unlock_exc
        if pending is None:
            return
        raise pending
 
522
 
 
523
 
 
524
class _ErrorFromCallable(Exception):
    """Helper for TestRunWithWriteLockedTarget.

    Marker exception raised by the callable passed to the function under
    test.
    """
 
526
 
 
527
 
 
528
class _ErrorFromUnlock(Exception):
    """Helper for TestRunWithWriteLockedTarget.

    Marker exception raised from the stub lockable's unlock().
    """
 
530
 
 
531
 
 
532
class TestRunWithWriteLockedTarget(TestCase):
    """Tests for _run_with_write_locked_target."""

    def setUp(self):
        TestCase.setUp(self)
        self._calls = []

    def func_that_returns_ok(self):
        self._calls.append('func called')
        return 'ok'

    def func_that_raises(self):
        self._calls.append('func called')
        raise _ErrorFromCallable()

    def _check_call_sequence(self):
        # Every scenario must lock, call, then unlock -- in that order.
        self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)

    def test_success_unlocks(self):
        stub = _StubLockable(self._calls)
        outcome = _run_with_write_locked_target(stub,
                                                self.func_that_returns_ok)
        self.assertEqual('ok', outcome)
        self._check_call_sequence()

    def test_exception_unlocks_and_propagates(self):
        stub = _StubLockable(self._calls)
        self.assertRaises(_ErrorFromCallable, _run_with_write_locked_target,
                          stub, self.func_that_raises)
        self._check_call_sequence()

    def test_callable_succeeds_but_error_during_unlock(self):
        stub = _StubLockable(self._calls, unlock_exc=_ErrorFromUnlock())
        self.assertRaises(_ErrorFromUnlock, _run_with_write_locked_target,
                          stub, self.func_that_returns_ok)
        self._check_call_sequence()

    def test_error_during_unlock_does_not_mask_original_error(self):
        stub = _StubLockable(self._calls, unlock_exc=_ErrorFromUnlock())
        self.assertRaises(_ErrorFromCallable, _run_with_write_locked_target,
                          stub, self.func_that_raises)
        self._check_call_sequence()
 
571
 
 
572