diff --git a/Project.toml b/Project.toml
index 15aa08e0..3231c2b9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,8 +6,12 @@ version = "0.8.0"
 
 [deps]
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+ScanByte = "7b38b023-a4d7-4c5e-8d43-3f3097f304eb"
 TranscodingStreams = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
 
+[compat]
+ScanByte = "0.3"
+
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/src/Automa.jl b/src/Automa.jl
index e1278148..5b0c14d4 100644
--- a/src/Automa.jl
+++ b/src/Automa.jl
@@ -2,6 +2,7 @@ module Automa
 
 using DataStructures: DefaultDict
 using Printf: @sprintf
+import ScanByte  # used qualified only: ScanByte.ByteSet and ScanByte.memchr
 
 include("sdict.jl")
 include("sset.jl")
diff --git a/src/byteset.jl b/src/byteset.jl
index a4fadf68..532f9176 100644
--- a/src/byteset.jl
+++ b/src/byteset.jl
@@ -29,6 +29,8 @@ function ByteSet(bytes::Union{UInt8,AbstractVector{UInt8},Set{UInt8}})
     return ByteSet(a, b, c, d)
 end
 
+ByteSet(x::ScanByte.ByteSet) = ByteSet(x.data...)  # splat x.data into the positional (a, b, c, d) constructor; assumes x.data holds four words - TODO confirm
+
 function Base.:(==)(s1::ByteSet, s2::ByteSet)
     return s1.a == s2.a && s1.b == s2.b && s1.c == s2.c && s1.d == s2.d
 end
@@ -119,6 +121,9 @@ function Base.maximum(set::ByteSet)
     end
 end
 
+Base.:~(x::ByteSet) = ByteSet(~x.a, ~x.b, ~x.c, ~x.d)  # complement: bitwise-negate each of the four fields
+iscontiguous(x::ByteSet) = maximum(x) - minimum(x) == length(x) - 1  # true iff the members form one gap-free run; NOTE(review): minimum/maximum throw ArgumentError on an empty set
+
 function isdisjoint(s1::ByteSet, s2::ByteSet)
     return isempty(intersect(s1, s2))
 end
diff --git a/src/codegen.jl b/src/codegen.jl
index a82a4ff1..61f5738e 100644
--- a/src/codegen.jl
+++ b/src/codegen.jl
@@ -71,6 +71,16 @@ function CodeGenContext(;
     elseif loopunroll > 0 && generator != :goto
         throw(ArgumentError("loop unrolling is not supported for $(generator)"))
     end
+    # The :simd generator requires loopunroll == 0, getbyte == Base.getindex and checkbounds == false.
+    if generator == :simd
+        if loopunroll != 0
+            throw(ArgumentError("SIMD generator does not support unrolling"))
+        elseif getbyte != Base.getindex
+            throw(ArgumentError("SIMD generator only support Base.getindex"))
+        elseif checkbounds  # note: callers must pass checkbounds=false explicitly (see test/simd.jl)
+            throw(ArgumentError("SIMD generator does not support boundscheck"))
+        end
+    end
     # check generator
     if generator == :table
         generator = generate_table_code
@@ -78,6 +88,8 @@ function CodeGenContext(;
         generator = generate_inline_code
     elseif generator == :goto
         generator = generate_goto_code
+    elseif generator == :simd
+        generator = generate_simd_code  # map the :simd symbol to its code-generating function
     else
         throw(ArgumentError("invalid code generator: $(generator)"))
     end
@@ -299,6 +311,101 @@ function generate_goto_code(ctx::CodeGenContext, machine::Machine, actions::Dict
     end
 end
 
+function generate_simd_code(ctx::CodeGenContext, machine::Machine, actions::Dict{Symbol,Expr})  # returns the Expr executing `machine` with @label/@goto states plus a SIMD scan for peelable self-loops
+    ## BEGIN: section marked by the author as identical to generate_goto_code
+    actions_in = Dict{Node,Set{Vector{Symbol}}}()  # target node => set of action-name lists found on its incoming edges
+    for s in traverse(machine.start), (e, t) in s.edges
+        push!(get!(actions_in, t, Set{Vector{Symbol}}()), action_names(e.actions))
+    end
+    action_label = Dict{Node,Dict{Vector{Symbol},Symbol}}()  # node => (action-name list => label symbol)
+    for s in traverse(machine.start)
+        action_label[s] = Dict()
+        if haskey(actions_in, s)
+            for (i, names) in enumerate(actions_in[s])
+                action_label[s][names] = Symbol("state_", s.state, "_action_", i)
+            end
+        end
+    end
+
+    blocks = Expr[]  # one code block per state, spliced into the result in traversal order
+    for s in traverse(machine.start)
+        block = Expr(:block)
+        for (names, label) in action_label[s]
+            if isempty(names)  # action-free edges jump straight to the state label, no action block needed
+                continue
+            end
+            append_code!(block, quote
+                @label $(label)
+                $(rewrite_special_macros(ctx, generate_action_code(names, actions), false, s.state))
+                @goto $(Symbol("state_", s.state))
+            end)
+        end
+
+        append_code!(block, quote
+            @label $(Symbol("state_", s.state))
+            $(ctx.vars.p) += 1  # entering the state consumes one byte
+            if $(ctx.vars.p) > $(ctx.vars.p_end)  # input exhausted: record the current state and leave
+                $(ctx.vars.cs) = $(s.state)
+                @goto exit
+            end
+        end)
+
+        ### END of the section marked as shared with generate_goto_code
+        simd, non_simd = peel_simd_edge(s)  # simd: self-loop edge without actions/preconditions, or nothing
+        simd_code = if simd !== nothing
+            quote
+                $(generate_simd_loop(ctx, simd.labels))
+                if $(ctx.vars.p) > $(ctx.vars.p_end)  # the scan may have moved p past the end of the input
+                    $(ctx.vars.cs) = $(s.state)
+                    @goto exit
+                end
+            end
+        else
+            :()  # no peelable self-loop: splice in an empty expression
+        end
+
+        default = :($(ctx.vars.cs) = $(-s.state); @goto exit)  # no edge matched: store the negated state number and leave
+        dispatch_code = foldr(default, optimize_edge_order(non_simd)) do edge, els  # nested if/else chain over the remaining edges
+            e, t = edge
+            if isempty(e.actions)
+                then = :(@goto $(Symbol("state_", t.state)))
+            else
+                then = :(@goto $(action_label[t][action_names(e.actions)]))
+            end
+            return Expr(:if, generate_condition_code(ctx, e, actions), then, els)
+        end
+        # BEGIN: from here on marked by the author as shared with generate_goto_code again
+        append_code!(block, quote
+            @label $(Symbol("state_case_", s.state))
+            $(simd_code)
+            $(generate_geybyte_code(ctx))  # NOTE(review): "geybyte" spelling presumably matches the helper's definition elsewhere - confirm
+            $(dispatch_code)
+        end)
+        push!(blocks, block)
+    end
+
+    enter_code = foldr(:(@goto exit), machine.states) do s, els  # jump to state_case_<cs>; falls through to exit for any other cs
+        return Expr(:if, :($(ctx.vars.cs) == $(s)), :(@goto $(Symbol("state_case_", s))), els)
+    end
+
+    eof_action_code = rewrite_special_macros(ctx, generate_eof_action_code(ctx, machine, actions), true)
+
+    return quote
+        if $(ctx.vars.p) > $(ctx.vars.p_end)  # nothing to read: skip straight to the exit handling
+            @goto exit
+        end
+        $(ctx.vars.mem) = $(SizedMemory)($(ctx.vars.data))
+        $(enter_code)
+        $(Expr(:block, blocks...))
+        @label exit
+        if $(ctx.vars.p) > $(ctx.vars.p_eof) ≥ 0 && $(ctx.vars.cs) ∈ $(machine.final_states)  # past EOF in a final state
+            $(eof_action_code)
+            $(ctx.vars.cs) = 0  # the branch taken for an accepted input sets cs to 0
+        end
+    end
+end
+
+
 function append_code!(block::Expr, code::Expr)
     @assert block.head == :block
     @assert code.head == :block
@@ -339,6 +446,29 @@ function generate_unrolled_loop(ctx::CodeGenContext, edge::Edge, t::Node)
     end
 end
 
+# Note: This function has been carefully crafted to produce (nearly) optimal
+# assembly code for AVX2-capable CPUs. Change with great care.
+function generate_simd_loop(ctx::CodeGenContext, bs::ByteSet)  # returns an Expr that advances p past a run of bytes contained in bs
+    byteset = ~ScanByte.ByteSet(bs)  # complemented set: presumably memchr reports the first byte that is NOT in bs - confirm against ScanByte
+    bsym = gensym()  # unique name for the scan result inside the generated code
+    quote
+        $bsym = $(loop_simd)(  # function object is interpolated (like SizedMemory) so the caller's module needs no `Automa` binding
+            $(ctx.vars.mem).ptr + $(ctx.vars.p) - 1,  # pointer to the byte at the 1-based position p
+            ($(ctx.vars.p_end) - $(ctx.vars.p) + 1) % UInt,  # number of bytes left in p:p_end
+            Val($byteset)
+        )
+        $(ctx.vars.p) = if $bsym === nothing  # no hit: the whole remainder was consumed
+            $(ctx.vars.p_end) + 1
+        else
+            $(ctx.vars.p) + $bsym - 1  # a hit is treated as a 1-based offset relative to p
+        end
+    end
+end
+
+@inline function loop_simd(ptr::Ptr, len::UInt, valbs::Val)  # thin inlined wrapper forwarding to ScanByte.memchr
+    ScanByte.memchr(ptr, len, valbs)
+end
+
 function generate_eof_action_code(ctx::CodeGenContext, machine::Machine, actions::Dict{Symbol,Expr})
     return foldr(:(), machine.eof_actions) do s_as, els
         s, as = s_as
@@ -492,6 +622,30 @@ function debug_actions(machine::Machine)
     return Dict{Symbol,Expr}(name => log_expr(name) for name in actions)
 end
 
+"""
+    peel_simd_edge(node) -> (simd, non_simd)
+
+Split the outgoing edges of `node` into at most one self-loop edge without
+actions or preconditions (`simd`, or `nothing` when there is none) and the
+`(edge, target)` pairs of all remaining edges (`non_simd`).
+
+Only the first qualifying self-loop is peeled off; any further one is kept in
+`non_simd`, so every edge of `node` ends up in exactly one of the two results.
+"""
+function peel_simd_edge(node)
+    non_simd = Tuple{Edge, Node}[]
+    simd = nothing
+    for (e, t) in node.edges
+        if simd === nothing && t === node && isempty(e.actions) && isempty(e.precond)
+            simd = e
+        else
+            push!(non_simd, (e, t))
+        end
+    end
+    return simd, non_simd
+end
+
+
 # Sort edges by its size in descending order.
 function optimize_edge_order(edges)
     return sort!(copy(edges), by=e->length(e[1].labels), rev=true)
diff --git a/src/sset.jl b/src/sset.jl
index 2b8bdfb4..2acba903 100644
--- a/src/sset.jl
+++ b/src/sset.jl
@@ -83,16 +83,6 @@ function Base.union(set::StableSet, xs)
     return union!(copy(set), xs)
 end
 
-function Base.filter(f::Function, set::StableSet)
-    newset = Set{eltype(set)}()
-    for x in set
-        if f(x)
-            push!(newset, x)
-        end
-    end
-    return newset
-end
-
 function Base.iterate(set::StableSet, s=iterate(set.dict))
     if s == nothing
         return nothing
diff --git a/test/runtests.jl b/test/runtests.jl
index 5cb96b49..7dba443b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -32,6 +32,13 @@ using Test
     @test_throws BoundsError mem[4]
 end
 
+@testset "ByteSet" begin
+    x = Automa.ByteSet()
+    @test isempty(x)
+    @test_throws ArgumentError minimum(x)
+    @test_throws ArgumentError maximum(x)
+end
+
 @testset "RegExp" begin
     @test_throws ArgumentError("invalid escape sequence: \\o") Automa.RegExp.parse("\\o")
 end
@@ -78,6 +85,7 @@ include("test16.jl")
 include("test17.jl")
 include("test18.jl")
 include("test19.jl")
+include("simd.jl")
 
 module TestFASTA
 using Test
diff --git a/test/simd.jl b/test/simd.jl
new file mode 100644
index 00000000..d3acb1fd
--- /dev/null
+++ b/test/simd.jl
@@ -0,0 +1,39 @@
+# CodeGenContext must reject an unknown generator and unsupported :simd options
+@testset "CodeGenContext" begin
+    @test_throws ArgumentError Automa.CodeGenContext(generator=:fdjfhkdj)
+    @test_throws ArgumentError Automa.CodeGenContext(generator=:simd)
+    @test_throws ArgumentError Automa.CodeGenContext(generator=:simd, checkbounds=false, loopunroll=2)
+    @test_throws ArgumentError Automa.CodeGenContext(generator=:simd, checkbounds=false, getbyte=identity)
+end
+
+import Automa
+const re = Automa.RegExp
+import Automa.RegExp: @re_str
+
+@testset "SIMD generator" begin
+    machine = let
+        seq = re"[A-Z]+"
+        name = re"[a-z]+"
+        rec = re">" * name * re"\n" * seq
+        Automa.compile(re.opt(rec) * re.rep(re"\n" * rec))
+    end
+
+    context = Automa.CodeGenContext(generator=:simd, checkbounds=false)
+
+    @eval function is_valid_fasta(data::String)
+        $(Automa.generate_init_code(context, machine))
+        p_end = p_eof = ncodeunits(data)
+        $(Automa.generate_exec_code(context, machine, nothing))
+        return cs == 0  # cs is set to 0 only when EOF is reached in a final state; p alone also passes truncated input
+    end
+
+    s1 = ">seq\nTAGGCTA\n>hello\nAJKGMP"
+    s2 = ">seq1\nTAGGC"
+    s3 = ">verylongsequencewherethesimdkicksin\nQ"
+    s4 = ">seq\n"  # truncated record: all bytes are consumed but the machine stops in a non-final state
+
+    for (seq, isvalid) in [(s1, true), (s2, false), (s3, true), (s4, false)]
+        @test is_valid_fasta(seq) == isvalid
+    end
+end
+