Moqing Entity Framework 6 .Include() using DbSet<>

I'd like to give the background to this question. Skip if you like. For quite a while I've paid close attention to the on going debates on stackoverflow and elsewhere regarding testing of code as it relates to EF. One camp says, test directly against a database because of the differences between the Linq to Objects & Sql and implementations. Another says test by mocking.

Another split in opinion is the issue of using repositories, or accepting that DbContext and DbSet already provide a unit of work and repository pattern. In the time that I've been using EF, I've tried about every combination of opinions provided by these camps. Regardless of what I've done, EF proves to be hard to test.

I was excited to find the EF team made DbSet more mockable in EF 6. They also provided documentation on how to mock DbSet, including async methods using Moq. In working on my latest project involving Web Api I realized that if I could mock EF, I could skip writing repositories, as the normal reason for writing them is to make things testable. Inspiration came after reading a few blog posts such as this...

--End of background ---

The actual problem is that following the example code given by the EF team on how to Moq DbSet, if .Include() is used in any code, an ArgumentNullException is thrown.

Other related post on SO

Here is my interface for DbContext:

public interface ITubingForcesDbContext
{
    DbSet<WellEntity> Wells { get; set; }

    int SaveChanges();

    Task<int> SaveChangesAsync();

    Task<int> SaveChangesAsync(CancellationToken cancellationToken);
}

This is the main entity that my controller deals with

public class WellEntity
{
    public int Id { get; set; }
    public DateTime DateUpdated { get; set; }
    public String UpdatedBy { get; set; }

    [Required]
    public string Name { get; set; }
    public string Location { get; set; }

    public virtual Company Company { get; set; }

    public virtual ICollection<GeometryItem> GeometryItems
    {
        get { return _geometryItems ?? (_geometryItems = new Collection<GeometryItem>()); }
        protected set { _geometryItems = value; }
    }
    private ICollection<GeometryItem> _geometryItems;

    public virtual ICollection<SurveyPoint> SurveyPoints
    {
        get { return _surveyPoints ?? (_surveyPoints = new Collection<SurveyPoint>()); }
        protected set { _surveyPoints = value; }
    }
    private ICollection<SurveyPoint> _surveyPoints;

    public virtual ICollection<TemperaturePoint> TemperaturePoints
    {
        get { return _temperaturePoints ?? (_temperaturePoints = new Collection<TemperaturePoint>()); }
        protected set { _temperaturePoints = value; }
    }
    private ICollection<TemperaturePoint> _temperaturePoints;
}

Here is the controller which directly uses an EF DbContext

 [Route("{id}")]
 public async Task<IHttpActionResult> Get(int id)
 {
        var query = await TheContext.Wells.
                                   Include(x => x.GeometryItems).
                                   Include(x => x.SurveyPoints).
                                   Include(x => x.TemperaturePoints).
                                   SingleOrDefaultAsync(x => x.Id == id);
        if (query == null)
        {
            return NotFound();
        }
        var model = ModelFactory.Create(query);
        return Ok(model);
}

Finally here is the failing test...

Test Setup---

   [ClassInitialize]
   public static void ClassInitialize(TestContext testContest)
        {

            var well1 = new WellEntity { Name = "Well 1" };
            var well2 = new WellEntity { Name = "Well 2" };
            var well3 = new WellEntity { Name = "Well 3" };
            var well4 = new WellEntity { Name = "Well 4" };

            well1.GeometryItems.Add(new GeometryItem());
            well1.TemperaturePoints.Add(new TemperaturePoint());
            well1.SurveyPoints.Add(new SurveyPoint());

            well2.GeometryItems.Add(new GeometryItem());
            well2.TemperaturePoints.Add(new TemperaturePoint());
            well2.SurveyPoints.Add(new SurveyPoint());

            well3.GeometryItems.Add(new GeometryItem());
            well3.TemperaturePoints.Add(new TemperaturePoint());
            well3.SurveyPoints.Add(new SurveyPoint());

            well4.GeometryItems.Add(new GeometryItem());
            well4.TemperaturePoints.Add(new TemperaturePoint());
            well4.SurveyPoints.Add(new SurveyPoint());

            var wells = new List<WellEntity> { well1, well2, well3, well4 }.AsQueryable();

            var mockWells = CreateMockSet(wells);

            _mockContext = new Mock<ITubingForcesDbContext>();
            _mockContext.Setup(c => c.Wells).Returns(mockWells.Object);
   }

   private static Mock<DbSet<T>> CreateMockSet<T>(IQueryable<T> data) where T : class
    {
        var mockSet = new Mock<DbSet<T>>();

        mockSet.As<IDbAsyncEnumerable<T>>()
            .Setup(m => m.GetAsyncEnumerator())
            .Returns(new TestDbAsyncEnumerator<T>(data.GetEnumerator()));

        mockSet.As<IQueryable<T>>()
               .Setup(m => m.Provider)
               .Returns(new TestDbAsyncQueryProvider<T>(data.Provider));

        mockSet.As<IQueryable<T>>().Setup(m => m.Expression).Returns(data.Expression);
        mockSet.As<IQueryable<T>>().Setup(m =>m.ElementType).Returns(data.ElementType);
        mockSet.As<IQueryable<T>>().Setup(m=>m.GetEnumerator()).
        Returns(data.GetEnumerator());

        return mockSet;
   }

  [TestMethod]  
   public async Task Get_ById_ReturnsWellWithAllChildData()
    {
        // Arrange
        var controller = new WellsController(_mockContext.Object);

        // Act
        var actionResult = await controller.Get(1);

        // Assert
        var response = actionResult as OkNegotiatedContentResult<WellModel>;
        Assert.IsNotNull(response);
        Assert.IsNotNull(response.Content.GeometryItems);
        Assert.IsNotNull(response.Content.SurveyPoints);
        Assert.IsNotNull(response.Content.TemperaturePoints);
   }

TestDbAsyncQueryProvider & TestDbAsyncEnumerator come directly from the referenced EF team documentation. I've tried several different variations for how I create the data for the mock, haven't had any luck with it.


Solution 1:

For anyone who stumbles upon this issue with interest on how to solve the .Include("Foo") problem with NSubstitute and Entity Framework 6+, I was able to bypass my Include calls in the following way:

var data = new List<Foo>()
{
    /* Stub data */
}.AsQueryable();

var mockSet = Substitute.For<DbSet<Foo>, IQueryable<Foo>>();
((IQueryable<Post>)mockSet).Provider.Returns(data.Provider);
((IQueryable<Post>)mockSet).Expression.Returns(data.Expression);
((IQueryable<Post>)mockSet).ElementType.Returns(data.ElementType);
((IQueryable<Post>)mockSet).GetEnumerator().Returns(data.GetEnumerator());

// The following line bypasses the Include call.
mockSet.Include(Arg.Any<string>()).Returns(mockSet);

Solution 2:

Here is a complete example using Moq. You can paste the entire example into your unit test class. Thanks to comments by @jbaum012 and @Skuli. I also recommend the excellent tutorial from Microsoft.

// An Address entity
public class Address
{
    public int Id { get; set; }
    public string Line1 { get; set; }
}

// A Person referencing Address
public class Person
{
    public int Id { get; set; }
    public string Name { get; set; }
    public virtual Address Address { get; set; }
}

// A DbContext with persons and devices
// Note use of virtual (see the tutorial reference)
public class PersonContext : DbContext
{
    public virtual DbSet<Person> Persons { get; set; }
    public virtual DbSet<Address> Addresses { get; set; }
}

// A simple class to test
// The dbcontext is injected into the controller
public class PersonsController
{
    private readonly PersonContext _personContext;

    public PersonsController(PersonContext personContext)
    {
        _personContext = personContext;
    }

    public IEnumerable<Person> GetPersons()
    {
        return _personContext.Persons.Include("Address").ToList();
    }
}

// Test the controller above
[TestMethod]
public void GetPersonsTest()
{
    var address = new Address { Id = 1, Line1 = "123 Main St." };
    var expectedPersons = new List<Person>
    {
        new Person { Id = 1, Address = address, Name = "John" },
        new Person { Id = 2, Address = address, Name = "John Jr." },
    };

    var mockPersonSet = GetMockDbSet(expectedPersons.AsQueryable());
    mockPersonSet.Setup(m => m.Include("Address")).Returns(mockPersonSet.Object);

    var mockPersonContext = new Mock<PersonContext>();
    mockPersonContext.Setup(o => o.Persons).Returns(mockPersonSet.Object);

    // test the controller GetPersons() method, which leverages Include()
    var controller = new PersonsController(mockPersonContext.Object);
    var actualPersons = controller.GetPersons();
    CollectionAssert.AreEqual(expectedPersons, actualPersons.ToList());
}

// a helper to make dbset queryable
private Mock<DbSet<T>> GetMockDbSet<T>(IQueryable<T> entities) where T : class
{
    var mockSet = new Mock<DbSet<T>>();
    mockSet.As<IQueryable<T>>().Setup(m => m.Provider).Returns(entities.Provider);
    mockSet.As<IQueryable<T>>().Setup(m => m.Expression).Returns(entities.Expression);
    mockSet.As<IQueryable<T>>().Setup(m => m.ElementType).Returns(entities.ElementType);
    mockSet.As<IQueryable<T>>().Setup(m => m.GetEnumerator()).Returns(entities.GetEnumerator());
    return mockSet;
}