Largest number divisible by 3 formed from some or all elements of an array

I am trying to solve the following question:

Given an array of 1 to 9 elements, each a digit from 0 to 9, what is the largest number divisible by 3 that can be formed using some or all of the elements in the array?
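The problem reduces to choosing which digits to keep. Because 10 leaves remainder 1 when divided by 3, a number leaves the same remainder mod 3 as its digit sum, so the order of the digits never affects divisibility. Once the digits are chosen, arranging them in descending order gives the largest number. A quick brute-force check of both facts on the digits kept in the second test case below (the helper `to_int` is a hypothetical name used only here):

```python
from itertools import permutations


def to_int(digits):
    """Concatenate single digits into an integer, e.g. (6, 3) -> 63."""
    return int(''.join(map(str, digits)))


# Digits kept for the test case [1, 5, 0, 6, 3, 5, 6] after removing one 5.
digits = [1, 0, 6, 3, 5, 6]

# A number is congruent to its digit sum mod 3, so every ordering
# of the same digits has the same remainder.
residues = {to_int(p) % 3 for p in permutations(digits)}
print(residues, sum(digits) % 3)  # {0} 0

# Among all orderings, the descending one is the largest.
best = max(to_int(p) for p in permutations(digits))
print(best, to_int(sorted(digits, reverse=True)))  # 665310 665310
```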

The question only accepts solutions in Java or Python, and I chose Python despite having no experience with it.

I looked around, and the general idea seems to be to remove the smallest digit(s) needed to make the digit sum divisible by 3, so I wrote the following "subtractive" approach:

def maxdiv3(l):
    l.sort()

    tot = 0
    for i in l:
        tot += i

    if tot % 3 == 0:
        l.sort(reverse=True)
        return int(''.join(str(e) for e in l))
    elif tot % 3 == 1:
        cl = [] # A copy of the list but only for elements % 3 != 0
        acl = [] # Anti copy of the list, only for elements % 3 = 0
        for i in l:
            if i % 3 == 0:
                acl.append(i)
            else:
                cl.append(i)

        removed = False
        nl = [] # A new list for the final results
        for i in cl:
            if not removed:
                if i % 3 == 1:
                    removed = True
                else:
                    nl.append(i)
            else:
                nl.append(i)

        if removed:
            nl.extend(acl)
            nl.sort(reverse=True)
            if len(nl) > 0:
                return int(''.join(str(e) for e in nl))
            else:
                return 0
        else:
            if len(acl) > 0:
                acl.sort(reverse=True)
                return int(''.join(str(e) for e in acl))

        return 0
    elif tot % 3 == 2:
        cl = []
        acl = []
        for i in l:
            if i % 3 == 0:
                acl.append(i)
            else:
                cl.append(i)

        removed2 = False
        nl = []
        for i in cl:
            if not removed2:
                if i % 3 == 2:
                    removed2 = True
                else:
                    nl.append(i)
            else:
                nl.append(i)

        if removed2:
            nl.extend(acl)
            nl.sort(reverse=True)
            if len(nl) > 0:
                return int(''.join(str(e) for e in nl))

        removed1 = 0
        nl = []
        for i in cl:
            if removed1 < 2:
                if i % 3 == 1:
                    removed1 += 1
                else:
                    nl.append(i)
            else:
                nl.append(i)

        if removed1 == 2:
            nl.extend(acl)
            nl.sort(reverse=True)
            if len(nl) > 0:
                return int(''.join(str(e) for e in nl))

        if len(acl) > 0:
            acl.sort(reverse=True)
            return int(''.join(str(e) for e in acl))
        else:
            return 0

This approach keeps failing a hidden test case, so I can't work out which input breaks it or why.

Based on an existing Stack Overflow solution, I wrote a new version:

def maxdiv3(l):
    l.sort()

    l0 = []
    l1 = []
    l2 = []
    for i in l:
        if i % 3 == 0:
            l0.append(i)
        elif i % 3 == 1:
            l1.append(i)
        elif i % 3 == 2:
            l2.append(i)

    tot = sum(l)

    nl = []

    if tot % 3 == 0:
        nl = l

        nl.sort(reverse=True)

        if len(nl) > 0:
            return int(''.join(str(e) for e in nl))
        
        return 0
    elif tot % 3 == 1:
        if len(l1) > 0:
            l1.remove(l1[0])

            nl.extend(l0)
            nl.extend(l1)
            nl.extend(l2)

            nl.sort(reverse=True)

            if len(nl) > 0:
                return int(''.join(str(e) for e in nl))

            return 0
        elif len(l2) > 1:
            l2.remove(l2[0])
            l2.remove(l2[0])

            nl.extend(l0)
            nl.extend(l1)
            nl.extend(l2)

            nl.sort(reverse=True)

            if len(nl) > 0:
                return int(''.join(str(e) for e in nl))

            return 0
        else:
            return 0
    elif tot % 3 == 2:
        if len(l2) > 0:
            l2.remove(l2[0])

            nl.extend(l0)
            nl.extend(l1)
            nl.extend(l2)

            nl.sort(reverse=True)

            if len(nl) > 0:
                return int(''.join(str(e) for e in nl))

            return 0
        elif len(l1) > 1:
            l1.remove(l1[0])
            l1.remove(l1[0])

            nl.extend(l0)
            nl.extend(l1)
            nl.extend(l2)

            nl.sort(reverse=True)

            if len(nl) > 0:
                return int(''.join(str(e) for e in nl))

            return 0
        else:
            return 0

And this one does pass all the test cases, including the hidden ones.
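For comparison, the remainder-bucket logic of the working version can be condensed, because the remainder-1 and remainder-2 branches mirror each other: either drop the smallest digit with the same remainder as the total, or drop the two smallest digits with the other nonzero remainder. A sketch under those assumptions (the name `max_div3` is hypothetical, and unlike the code above it does not sort or otherwise modify its argument):

```python
def max_div3(digits):
    """Largest number divisible by 3 formed from some or all of `digits`,
    or 0 if no selection works (the convention used by the test cases)."""
    # Bucket digits by remainder mod 3; each bucket is sorted ascending,
    # so the smallest candidates for removal come first.
    buckets = {r: sorted(d for d in digits if d % 3 == r) for r in range(3)}
    r = sum(digits) % 3
    if r != 0:
        other = 3 - r  # two digits with remainder `other` sum to remainder r
        if buckets[r]:
            buckets[r].pop(0)        # drop the smallest digit with remainder r
        elif len(buckets[other]) >= 2:
            del buckets[other][:2]   # drop the two smallest with remainder `other`
        else:
            return 0                 # cannot happen for valid input; safeguard only
    kept = sorted(buckets[0] + buckets[1] + buckets[2], reverse=True)
    return int(''.join(map(str, kept))) if kept else 0


print(max_div3([2, 2, 5, 5, 8]))  # 855
```

Removing one digit is always preferred over removing two, since a number with more digits is larger (or all remaining digits are 0, in which case both options give 0).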

Here are some of the test cases that I ran my attempt through:

[3, 9, 5, 2] -> 93
[1, 5, 0, 6, 3, 5, 6] -> 665310
[5, 2] -> 0
[1] -> 0
[2] -> 0
[1, 1] -> 0
[9, 5, 5] -> 9
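Because the input has at most 9 digits, there are at most 2^9 - 1 = 511 non-empty selections, so an exhaustive search is cheap enough to serve as a reference answer. The following sketch (the name `max_div3_brute` is hypothetical) confirms the expected outputs listed above:

```python
from itertools import combinations


def max_div3_brute(digits):
    """Exhaustive reference: try every non-empty subset, keep the largest
    descending arrangement whose digit sum is divisible by 3; 0 if none."""
    best = 0
    for k in range(1, len(digits) + 1):
        for subset in combinations(digits, k):
            if sum(subset) % 3 == 0:
                value = int(''.join(map(str, sorted(subset, reverse=True))))
                best = max(best, value)
    return best


cases = {
    (3, 9, 5, 2): 93,
    (1, 5, 0, 6, 3, 5, 6): 665310,
    (5, 2): 0,
    (1,): 0,
    (2,): 0,
    (1, 1): 0,
    (9, 5, 5): 9,
}
for digits, expected in cases.items():
    assert max_div3_brute(list(digits)) == expected, digits
print("all expected values confirmed")
```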

It seems to me that my attempt and the SO solution have the same idea, so what did I fail to consider? How does the SO solution differ from mine, and how does that difference catch whatever my attempt misses?

Thank you for your time.


A slightly off-topic additional question: how do I find edge cases when the test cases are hidden? I built a random input generator for this question, but it didn't help nearly as much as I had hoped, and it is likely not a good general solution.


How do I find edge cases when dealing with blind test cases?

Here is the tester I used:

import random

digits = list(range(10))
for k in range(1, 10):  # Try different sizes
    for _ in range(100):  # Repeat many times
        lst = random.choices(digits, k=k)  # Produce random digits
        a = maxdiv3(lst)  # Run working solution
        b = maxdiv3b(lst)  # Run solution with a problem (your first attempt, renamed)
        if a != b:  # Found a deviation!
            print(lst)
            break  # Stop this size; move on to the next k

This was one of the lists I got:

[2, 2, 5, 5, 8]

Then I traced your code with that input and ended up in this block:

    else:
        if len(acl) > 0:
            acl.sort(reverse=True)
            return int(''.join(str(e) for e in acl))

We are here in the case where the total has a remainder of 1 when divided by 3, and we reach this else block only when no individual digit has that remainder. The block then outputs just the digits that are multiples of 3, which is not always right. When there are at least two digits with a remainder of 2 (i.e. 2, 5, or 8), such a pair sums to a remainder of 1 (for example 2 + 2 = 4), so you only have to remove those two digits, not every digit that is not a multiple of 3.
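Worked through on the failing input, the arithmetic looks like this. The buggy branch returns 0 here because there are no multiples of 3 to fall back on, while dropping the two 2s leaves digits that sum to 18:

```python
digits = [2, 2, 5, 5, 8]

print(sum(digits) % 3)           # 22 % 3 -> 1
print([d % 3 for d in digits])   # [2, 2, 2, 2, 2]: no digit has remainder 1
print((2 + 2) % 3)               # two remainder-2 digits together have remainder 1

# Here every digit has remainder 2, so the two smallest digits overall
# are also the two smallest remainder-2 digits.
kept = sorted(digits)[2:]
print(kept, sum(kept) % 3)       # [5, 5, 8] 0
print(int(''.join(map(str, sorted(kept, reverse=True)))))  # 855
```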

The correction is to remove two of those digits (the smallest) and then join the two lists as you did elsewhere in the code:

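    else:
        # No digit has remainder 1, so nl holds all of cl in ascending order,
        # and every one of those digits has remainder 2: drop the two smallest
        del nl[0:2]
        nl.extend(acl)
        nl.sort(reverse=True)
        if len(nl) > 0:
            return int(''.join(str(e) for e in nl))
        else:
            return 0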
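To see the fix in context, here is the question's `tot % 3 == 1` branch pulled out into a standalone function with the correction applied. The function name is hypothetical, and it assumes its input's digit sum leaves remainder 1. The `del nl[0:2]` is safe: if no digit has remainder 1, the digit sum is congruent to twice the number of remainder-2 digits, and that is 1 mod 3 only when there are at least two of them.

```python
def maxdiv3_rem1_fixed(l):
    """Hypothetical standalone version of the question's `tot % 3 == 1`
    branch with the fix applied. Assumes sum(l) % 3 == 1."""
    l = sorted(l)                       # ascending, as after l.sort() in the question
    cl = [i for i in l if i % 3 != 0]   # digits with remainder 1 or 2
    acl = [i for i in l if i % 3 == 0]  # digits that are multiples of 3

    nl = []
    removed = False
    for i in cl:
        if not removed and i % 3 == 1:
            removed = True              # skip the smallest remainder-1 digit
        else:
            nl.append(i)

    if not removed:
        # Only remainder-2 digits are in nl (ascending); drop the two smallest.
        del nl[0:2]

    nl.extend(acl)
    nl.sort(reverse=True)
    return int(''.join(str(e) for e in nl)) if nl else 0


print(maxdiv3_rem1_fixed([2, 2, 5, 5, 8]))  # 855
```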

NB: It didn't help that the chosen names are not very descriptive. I have no clue what acl, nl, or cl are abbreviations of.
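On the side question about finding edge cases: since the inputs are tiny (at most 9 digits, each 0-9) and the order of the digits does not affect the answer, you can test every possible input instead of random samples. There are 92,377 multisets of 1 to 9 digits. The sketch below (all names are hypothetical) enumerates them in order of increasing length, so the first disagreement it reports is also a shortest one. It is demonstrated against a deliberately flawed strategy that always falls back to the multiples of 3, similar in spirit to the bug described above. Using the brute-force reference over the full input space is noticeably slower than comparing two fast implementations, so `max_len` can cap the input length.

```python
from itertools import combinations, combinations_with_replacement


def brute(digits):
    """Reference answer by exhaustive search over subsets (0 if none qualifies)."""
    best = 0
    for k in range(1, len(digits) + 1):
        for subset in combinations(digits, k):
            if sum(subset) % 3 == 0:
                best = max(best, int(''.join(map(str, sorted(subset, reverse=True)))))
    return best


def all_inputs(max_len=9):
    """Every multiset of digits 0-9 with 1..max_len elements, as sorted lists."""
    for k in range(1, max_len + 1):
        for combo in combinations_with_replacement(range(10), k):
            yield list(combo)


def first_mismatch(candidate, reference, max_len=9):
    """Return the first (shortest) input where the two functions disagree, or None."""
    for digits in all_inputs(max_len):
        # Pass copies so a function that sorts in place cannot affect the other.
        if candidate(digits[:]) != reference(digits[:]):
            return digits
    return None


def only_multiples_of_3(digits):
    """Deliberately flawed strategy: if the digit sum is not divisible by 3,
    keep only the digits that are themselves multiples of 3."""
    kept = digits if sum(digits) % 3 == 0 else [d for d in digits if d % 3 == 0]
    kept = sorted(kept, reverse=True)
    return int(''.join(map(str, kept))) if kept else 0


print(first_mismatch(only_multiples_of_3, brute))  # [1, 1, 2]
```

For [1, 1, 2] the flawed strategy returns 0, while dropping a single 1 gives 21.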