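C# LINQ Source Code Analysis: Take (Part 4)

Overview

This article analyzes the source code of Take's optimized path, looking at how Take is optimized when it is combined with other common LINQ extension methods such as Select and Where.

The analysis touches on three methods: Select, Where, and Take. For a more detailed source code analysis of each, see my earlier articles.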

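Source Code Analysis

Our earlier analysis of Take mainly covered the case where Take is called directly on a data sequence. The optimization described here targets chains of several LINQ methods, such as xx.Where.Select.Take or xx.Select.Take.

The Take optimization is defined in Take.SpeedOpt.cs and lives in the TakeIterator method below. For how Take calls TakeIterator, see my earlier article C# Linq源码分析之Take方法.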

private static IEnumerable<TSource> TakeIterator<TSource>(IEnumerable<TSource> source, int count)
 {
     Debug.Assert(source != null);
     Debug.Assert(count > 0);

     return
         source is IPartition<TSource> partition ? partition.Take(count) :
         source is IList<TSource> sourceList ? new ListPartition<TSource>(sourceList, 0, count - 1) :
         new EnumerablePartition<TSource>(source, 0, count - 1);
 }

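The basic logic of this optimization is as follows:

  1. If the source sequence implements the IPartition interface, call the Take method defined on IPartition;
  2. Otherwise, if the source sequence implements the IList interface, return a ListPartition object;
  3. Otherwise, return an EnumerablePartition object.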
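
The three branches can be exercised from user code with three kinds of source. The snippet below is a minimal sketch, not library code: the names TakeBranchDemo, Generate, FromSelect, FromList, and FromIterator are made up for illustration, and the comments about which branch each source reaches assume the version of System.Linq analyzed in this article. The printed type names are internal details that differ between .NET versions, so no particular type name is assumed.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class TakeBranchDemo
{
    // A lazily generated sequence. It is not an IList<T>, so under the
    // assumptions above it would fall through to the EnumerablePartition branch.
    public static IEnumerable<int> Generate()
    {
        for (int i = 1; i <= 5; i++)
        {
            yield return i;
        }
    }

    // Select over a List<T>: in the analyzed source the resulting iterator
    // implements IPartition<T>, so Take is forwarded to that iterator.
    public static List<int> FromSelect() =>
        new List<int> { 1, 2, 3, 4, 5 }.Select(x => x * 10).Take(3).ToList();

    // A plain List<T> implements IList<T> (the ListPartition branch).
    public static List<int> FromList() =>
        new List<int> { 1, 2, 3, 4, 5 }.Take(3).ToList();

    // Neither of the above (the EnumerablePartition branch).
    public static List<int> FromIterator() => Generate().Take(3).ToList();

    public static void Main()
    {
        var list = new List<int> { 1, 2, 3, 4, 5 };

        // Runtime types of the objects returned by Take; names vary by .NET version.
        Console.WriteLine(list.Select(x => x * 10).Take(3).GetType().Name);
        Console.WriteLine(list.Take(3).GetType().Name);
        Console.WriteLine(Generate().Take(3).GetType().Name);

        // Whatever branch is taken, the observable result is the first three items.
        Console.WriteLine(string.Join(",", FromSelect()));   // 10,20,30
        Console.WriteLine(string.Join(",", FromList()));     // 1,2,3
        Console.WriteLine(string.Join(",", FromIterator())); // 1,2,3
    }
}
```

The branches differ only in how the result is produced, never in what is produced, which is why the optimization can be applied transparently.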

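Case Study

Select.Take

This scenario models a common real-world situation: converting database-backed EF entities into the objects a web front end needs, and paging the result.

We take the student's name and math score from each Student object and then page the result. The Student class is defined in the appendix.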

studentList.Select(x => new {
    x.Name, x.MathResult
}).Take(3).ToList();

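The execution flow, shown in the figure above, is as follows:

  1. Select returns a SelectListIterator object, which stores the source sequence and the selector;
  2. In Take, because SelectListIterator implements the IPartition interface, its own Take method is called. That method instantiates a SelectListPartitionIterator from source, selector, and count (the argument passed to Take);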
public IPartition<TResult> Take(int count)
{
    Debug.Assert(count > 0);
    return new SelectListPartitionIterator<TSource, TResult>(_source, _selector, 0, count - 1);
}
  3. In ToList, per the code below, because SelectListPartitionIterator implements the IPartition interface, and IPartition in turn inherits the IIListProvider interface, listProvider.ToList() is executed.
public static List<TSource> ToList<TSource>(this IEnumerable<TSource> source)
{
    if (source == null)
    {
        ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
    }

    return source is IIListProvider<TSource> listProvider ? listProvider.ToList() : new List<TSource>(source);
}

Executing listProvider.ToList() means running the ToList() method defined on SelectListPartitionIterator. That method performs Select's projection and Take's truncation in the same loop, as the code below shows:

 public List<TResult> ToList()
{
    int count = Count;
    if (count == 0)
    {
        return new List<TResult>();
    }

    List<TResult> list = new List<TResult>(count);
    int end = _minIndexInclusive + count;
    for (int i = _minIndexInclusive; i != end; ++i)
    {
        list.Add(_selector(_source[i]));
    }

    return list;
}
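  1. If the partition contains no elements (Count is 0), an empty List is returned immediately and the selector is never invoked;
  2. Otherwise, only the elements inside the range fixed by Take's count are read, and the projection is applied to each of them;
  3. The resulting list is returned.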
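
The three steps above can be sketched as a standalone method. This is a simplified re-implementation for illustration, assuming the same inclusive index range that SelectListPartitionIterator stores; the names PartitionToListSketch, CountInRange, and ProjectRange are hypothetical and do not exist in System.Linq.

```csharp
using System;
using System.Collections.Generic;

public static class PartitionToListSketch
{
    // Mirrors the Count property: how many source elements fall inside
    // the inclusive range [minIndexInclusive, maxIndexInclusive].
    public static int CountInRange(int sourceCount, int minIndexInclusive, int maxIndexInclusive)
    {
        if (sourceCount <= minIndexInclusive)
        {
            return 0;
        }

        return Math.Min(sourceCount - 1, maxIndexInclusive) - minIndexInclusive + 1;
    }

    // Mirrors ToList: projection and truncation happen in the same loop.
    public static List<TResult> ProjectRange<TSource, TResult>(
        IList<TSource> source, Func<TSource, TResult> selector,
        int minIndexInclusive, int maxIndexInclusive)
    {
        int count = CountInRange(source.Count, minIndexInclusive, maxIndexInclusive);
        if (count == 0)
        {
            return new List<TResult>(); // step 1: the selector is never called
        }

        var list = new List<TResult>(count);
        int end = minIndexInclusive + count;
        for (int i = minIndexInclusive; i != end; ++i)
        {
            list.Add(selector(source[i])); // step 2: project only the elements in range
        }

        return list; // step 3
    }

    public static void Main()
    {
        var scores = new List<int> { 90, 75, 60, 85, 70 };

        // Select(...).Take(3) over the list corresponds to the range [0, 2].
        Console.WriteLine(string.Join(",", ProjectRange(scores, s => s + 1, 0, 2))); // 91,76,61

        // Asking for more than exists is clamped to the list size.
        Console.WriteLine(ProjectRange(scores, s => s + 1, 0, 9).Count); // 5

        // A range that starts past the end is empty.
        Console.WriteLine(ProjectRange(scores, s => s + 1, 7, 9).Count); // 0
    }
}
```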

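In this way, LINQ walks the source List only once, and only over the requested index range, completing both the projection and the truncation in a single pass. That is the optimization.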
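
One way to check the effect from user code is to count how often the selector runs. The sketch below uses a hypothetical helper, SelectorCallCounter.CountSelectorCalls, introduced only for this example. Note that the count by itself does not prove the fast path was taken, since ordinary lazy enumeration would also stop after the requested number of elements. What the fast path adds is that the result list is filled by direct index access, without going through an enumerator.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class SelectorCallCounter
{
    // Builds a list of sourceSize integers, runs Select(...).Take(take).ToList(),
    // and reports how many times the selector was invoked.
    public static int CountSelectorCalls(int sourceSize, int take)
    {
        List<int> source = Enumerable.Range(1, sourceSize).ToList();
        int calls = 0;

        source.Select(x => { calls++; return x * 2; }).Take(take).ToList();

        return calls;
    }

    public static void Main()
    {
        // Only the three requested elements are projected, not all 1000.
        Console.WriteLine(CountSelectorCalls(1000, 3)); // 3

        // When the list is shorter than the requested count, every element is projected once.
        Console.WriteLine(CountSelectorCalls(2, 3)); // 2
    }
}
```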

Appendix

Student class

public class Student {
    public string Id { get; set; }
    public string Name { get; set; }
    public string Classroom { get; set; }
    public int MathResult { get; set; }
}

IIListProvider interface

internal interface IIListProvider<TElement> : IEnumerable<TElement>
{
    TElement[] ToArray();
    List<TElement> ToList();
    int GetCount(bool onlyIfCheap);
}

IPartition interface

internal interface IPartition<TElement> : IIListProvider<TElement>
{

    IPartition<TElement> Skip(int count);

    IPartition<TElement> Take(int count);

    TElement? TryGetElementAt(int index, out bool found);

    TElement? TryGetFirst(out bool found);

    TElement? TryGetLast(out bool found);
}

SelectListPartitionIterator class

 private sealed class SelectListPartitionIterator<TSource, TResult> : Iterator<TResult>, IPartition<TResult>
{
    private readonly IList<TSource> _source;
    private readonly Func<TSource, TResult> _selector;
    private readonly int _minIndexInclusive;
    private readonly int _maxIndexInclusive;

    public SelectListPartitionIterator(IList<TSource> source, Func<TSource, TResult> selector, int minIndexInclusive, int maxIndexInclusive)
    {
        Debug.Assert(source != null);
        Debug.Assert(selector != null);
        Debug.Assert(minIndexInclusive >= 0);
        Debug.Assert(minIndexInclusive <= maxIndexInclusive);
        _source = source;
        _selector = selector;
        _minIndexInclusive = minIndexInclusive;
        _maxIndexInclusive = maxIndexInclusive;
    }

    public override Iterator<TResult> Clone() =>
        new SelectListPartitionIterator<TSource, TResult>(_source, _selector, _minIndexInclusive, _maxIndexInclusive);

    public override bool MoveNext()
    {
        // _state - 1 represents the zero-based index into the list.
        // Having a separate field for the index would be more readable. However, we save it
        // into _state with a bias to minimize field size of the iterator.
        int index = _state - 1;
        if (unchecked((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive) && index < _source.Count - _minIndexInclusive))
        {
            _current = _selector(_source[_minIndexInclusive + index]);
            ++_state;
            return true;
        }

        Dispose();
        return false;
    }

    public override IEnumerable<TResult2> Select<TResult2>(Func<TResult, TResult2> selector) =>
        new SelectListPartitionIterator<TSource, TResult2>(_source, CombineSelectors(_selector, selector), _minIndexInclusive, _maxIndexInclusive);

    public IPartition<TResult> Skip(int count)
    {
        Debug.Assert(count > 0);
        int minIndex = _minIndexInclusive + count;
        return (uint)minIndex > (uint)_maxIndexInclusive ? EmptyPartition<TResult>.Instance : new SelectListPartitionIterator<TSource, TResult>(_source, _selector, minIndex, _maxIndexInclusive);
    }

    public IPartition<TResult> Take(int count)
    {
        Debug.Assert(count > 0);
        int maxIndex = _minIndexInclusive + count - 1;
        return (uint)maxIndex >= (uint)_maxIndexInclusive ? this : new SelectListPartitionIterator<TSource, TResult>(_source, _selector, _minIndexInclusive, maxIndex);
    }

    public TResult? TryGetElementAt(int index, out bool found)
    {
        if ((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive) && index < _source.Count - _minIndexInclusive)
        {
            found = true;
            return _selector(_source[_minIndexInclusive + index]);
        }

        found = false;
        return default;
    }

    public TResult? TryGetFirst(out bool found)
    {
        if (_source.Count > _minIndexInclusive)
        {
            found = true;
            return _selector(_source[_minIndexInclusive]);
        }

        found = false;
        return default;
    }

    public TResult? TryGetLast(out bool found)
    {
        int lastIndex = _source.Count - 1;
        if (lastIndex >= _minIndexInclusive)
        {
            found = true;
            return _selector(_source[Math.Min(lastIndex, _maxIndexInclusive)]);
        }

        found = false;
        return default;
    }

    private int Count
    {
        get
        {
            int count = _source.Count;
            if (count <= _minIndexInclusive)
            {
                return 0;
            }

            return Math.Min(count - 1, _maxIndexInclusive) - _minIndexInclusive + 1;
        }
    }

    public TResult[] ToArray()
    {
        int count = Count;
        if (count == 0)
        {
            return Array.Empty<TResult>();
        }

        TResult[] array = new TResult[count];
        for (int i = 0, curIdx = _minIndexInclusive; i != array.Length; ++i, ++curIdx)
        {
            array[i] = _selector(_source[curIdx]);
        }

        return array;
    }

    public List<TResult> ToList()
    {
        int count = Count;
        if (count == 0)
        {
            return new List<TResult>();
        }

        List<TResult> list = new List<TResult>(count);
        int end = _minIndexInclusive + count;
        for (int i = _minIndexInclusive; i != end; ++i)
        {
            list.Add(_selector(_source[i]));
        }

        return list;
    }

    public int GetCount(bool onlyIfCheap)
    {
        // In case someone uses Count() to force evaluation of
        // the selector, run it provided `onlyIfCheap` is false.

        int count = Count;

        if (!onlyIfCheap)
        {
            int end = _minIndexInclusive + count;
            for (int i = _minIndexInclusive; i != end; ++i)
            {
                _selector(_source[i]);
            }
        }

        return count;
    }
}
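
The Take method of this class compares indexes after casting them to uint. A plausible reading, stated here as an assumption rather than as the authors' documented intent, is that the cast guards against integer overflow: _minIndexInclusive + count - 1 can wrap around to a negative number, and a negative int becomes a very large uint. The sketch below isolates that comparison; the names UnsignedCompareSketch and TakeCoversWholeRange are hypothetical.

```csharp
using System;

public static class UnsignedCompareSketch
{
    // Mirrors the test in Take. True means the requested range already covers
    // the current one, so the existing partition can be returned unchanged.
    public static bool TakeCoversWholeRange(int minIndexInclusive, int maxIndexInclusive, int count)
    {
        // The sum may wrap around to a negative int; unchecked makes that explicit.
        int maxIndex = unchecked(minIndexInclusive + count - 1);

        // After the cast, a wrapped (negative) maxIndex compares as a huge value.
        return unchecked((uint)maxIndex >= (uint)maxIndexInclusive);
    }

    public static void Main()
    {
        // Range [0, 9], Take(3): a narrower partition [0, 2] is needed.
        Console.WriteLine(TakeCoversWholeRange(0, 9, 3)); // False

        // Range [0, 9], Take(20): the current partition is already small enough.
        Console.WriteLine(TakeCoversWholeRange(0, 9, 20)); // True

        // Range [2, int.MaxValue] (as produced by Skip(2)), Take(int.MaxValue):
        // 2 + int.MaxValue - 1 wraps to int.MinValue, which a signed comparison
        // would treat as smaller than the current maximum.
        Console.WriteLine(TakeCoversWholeRange(2, int.MaxValue, int.MaxValue)); // True
    }
}
```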

SelectListIterator class (Select.cs)

 private sealed partial class SelectListIterator<TSource, TResult> : Iterator<TResult>
 {
     private readonly List<TSource> _source;
     private readonly Func<TSource, TResult> _selector;
     private List<TSource>.Enumerator _enumerator;

     public SelectListIterator(List<TSource> source, Func<TSource, TResult> selector)
     {
         Debug.Assert(source != null);
         Debug.Assert(selector != null);
         _source = source;
         _selector = selector;
     }

     private int CountForDebugger => _source.Count;

     public override Iterator<TResult> Clone() => new SelectListIterator<TSource, TResult>(_source, _selector);

     public override bool MoveNext()
     {
         switch (_state)
         {
             case 1:
                 _enumerator = _source.GetEnumerator();
                 _state = 2;
                 goto case 2;
             case 2:
                 if (_enumerator.MoveNext())
                 {
                     _current = _selector(_enumerator.Current);
                     return true;
                 }

                 Dispose();
                 break;
         }

         return false;
     }

     public override IEnumerable<TResult2> Select<TResult2>(Func<TResult, TResult2> selector) =>
         new SelectListIterator<TSource, TResult2>(_source, CombineSelectors(_selector, selector));
 }

Select.SpeedOpt.cs

private sealed partial class SelectListIterator<TSource, TResult> : IPartition<TResult>
 {
      public TResult[] ToArray()
      {
          int count = _source.Count;
          if (count == 0)
          {
              return Array.Empty<TResult>();
          }

          var results = new TResult[count];
          for (int i = 0; i < results.Length; i++)
          {
              results[i] = _selector(_source[i]);
          }

          return results;
      }

      public List<TResult> ToList()
      {
          int count = _source.Count;
          var results = new List<TResult>(count);
          for (int i = 0; i < count; i++)
          {
              results.Add(_selector(_source[i]));
          }

          return results;
      }

      public int GetCount(bool onlyIfCheap)
      {
          // In case someone uses Count() to force evaluation of
          // the selector, run it provided `onlyIfCheap` is false.

          int count = _source.Count;

          if (!onlyIfCheap)
          {
              for (int i = 0; i < count; i++)
              {
                  _selector(_source[i]);
              }
          }

          return count;
      }

      public IPartition<TResult> Skip(int count)
      {
          Debug.Assert(count > 0);
          return new SelectListPartitionIterator<TSource, TResult>(_source, _selector, count, int.MaxValue);
      }

      public IPartition<TResult> Take(int count)
      {
          Debug.Assert(count > 0);
          return new SelectListPartitionIterator<TSource, TResult>(_source, _selector, 0, count - 1);
      }

      public TResult? TryGetElementAt(int index, out bool found)
      {
          if (unchecked((uint)index < (uint)_source.Count))
          {
              found = true;
              return _selector(_source[index]);
          }

          found = false;
          return default;
      }

      public TResult? TryGetFirst(out bool found)
      {
          if (_source.Count != 0)
          {
              found = true;
              return _selector(_source[0]);
          }

          found = false;
          return default;
      }

      public TResult? TryGetLast(out bool found)
      {
          int len = _source.Count;
          if (len != 0)
          {
              found = true;
              return _selector(_source[len - 1]);
          }

          found = false;
          return default;
      }
  }
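
Because Skip and Take both return partitions here, they compose naturally for the paging scenario described in the case study. The following is a minimal sketch under stated assumptions: PagingDemo, GetPage, and SampleStudents are hypothetical helpers, the sample data is invented, and the page index is taken to be zero-based.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public class Student
{
    public string Id { get; set; }
    public string Name { get; set; }
    public string Classroom { get; set; }
    public int MathResult { get; set; }
}

public static class PagingDemo
{
    // Invented sample data, for illustration only.
    public static List<Student> SampleStudents() => new List<Student>
    {
        new Student { Id = "1", Name = "Alice", Classroom = "A", MathResult = 90 },
        new Student { Id = "2", Name = "Bob",   Classroom = "A", MathResult = 75 },
        new Student { Id = "3", Name = "Carol", Classroom = "B", MathResult = 60 },
        new Student { Id = "4", Name = "Dave",  Classroom = "B", MathResult = 85 },
        new Student { Id = "5", Name = "Eve",   Classroom = "C", MathResult = 70 },
    };

    // Projects each student to a name, then returns one zero-based page.
    // In the analyzed source, Skip and Take on the Select iterator only adjust
    // an index range; elements are read when ToList runs.
    public static List<string> GetPage(List<Student> students, int pageIndex, int pageSize) =>
        students.Select(s => s.Name)
                .Skip(pageIndex * pageSize)
                .Take(pageSize)
                .ToList();

    public static void Main()
    {
        List<Student> students = SampleStudents();

        Console.WriteLine(string.Join(",", GetPage(students, 0, 2))); // Alice,Bob
        Console.WriteLine(string.Join(",", GetPage(students, 1, 2))); // Carol,Dave
        Console.WriteLine(string.Join(",", GetPage(students, 2, 2))); // Eve
        Console.WriteLine(GetPage(students, 3, 2).Count);             // 0
    }
}
```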