import z3

# Exercise 2.1: deciding continuous reachability in a Petri net via SMT solving (z3)
if __name__ == "__main__":
    # Petri net of Exercise 2.1: its places P and its transitions T.
    P = {"p1", "p2", "p3", "p4"}
    T = {"t1", "t2", "t3", "t4"}
    nodes = P | T

    solver = z3.Solver()

    # Real-valued unknowns of the encoding:
    #   u[p], v[p]          tokens on place p in the source / target marking
    #   x[t]                amount by which transition t is fired
    #   z_fwd[n], z_bwd[n]  per-node ordering values witnessing forward /
    #                       backward firability of the transitions used by x
    u = {place: z3.Real(f"u_{place}") for place in P}
    v = {place: z3.Real(f"v_{place}") for place in P}
    x = {trans: z3.Real(f"x_{trans}") for trans in T}

    # The loop name is `node` rather than `v` so the target-marking dict `v`
    # stays unambiguous when reading these lines.
    z_fwd = {node: z3.Real(f"z_fwd_{node}") for node in nodes}
    z_bwd = {node: z3.Real(f"z_bwd_{node}") for node in nodes}

    # Non-negativity: every unknown of the encoding ranges over the
    # non-negative reals. Each entry pairs a variable family with the
    # nodes that index it.
    for family, keys in ((u, P), (v, P), (x, T), (z_fwd, P | T), (z_bwd, P | T)):
        solver.add(z3.And([family[key] >= 0 for key in keys]))

    # Marking (state) equation, one row per place: the source tokens plus
    # the net effect of firing x must equal the target tokens.
    state_equation = [
        u["p1"] - x["t1"] - 2*x["t2"] - x["t3"] == v["p1"],
        u["p2"] + x["t1"] - x["t3"] == v["p2"],
        u["p3"] + x["t2"] + x["t3"] - x["t4"] == v["p3"],
        u["p4"] + x["t4"] == v["p4"],
    ]
    for equation in state_equation:
        solver.add(equation)

    # Firability constraints on the support of x (the transitions t with
    # x[t] > 0). Both directions have the same shape, so they are generated
    # from the net's arcs rather than written out by hand:
    #   * each transition in the support gets z[t] > 0. Every place it
    #     requires must be activated (z[p] > 0) strictly before it
    #     (z[p] < z[t]).
    #   * a place may only be activated if the given marking already puts
    #     tokens on it, or if some transition in the support that feeds it
    #     is activated strictly earlier.
    # Forward:  requires = pre-set,  feeds = post-set, marking = u, z = z_fwd.
    # Backward: requires = post-set, feeds = pre-set,  marking = v, z = z_bwd.

    # Arcs of the net: the places each transition consumes from (pre-set)
    # and produces into (post-set). Keep these in sync with the marking
    # equation. t2 both consumes and produces p4; t3 both consumes and
    # produces p1.
    pre = {"t1": ["p1"], "t2": ["p1", "p4"], "t3": ["p1", "p2"], "t4": ["p3"]}
    post = {"t1": ["p2"], "t2": ["p3", "p4"], "t3": ["p1", "p3"], "t4": ["p4"]}

    def add_firability(z, marking, requires, feeds):
        """Assert one direction of the firability constraints on ``solver``.

        Args:
            z: ordering variables (z_fwd or z_bwd), keyed by place and
                transition name.
            marking: marking variables (u or v), keyed by place name.
            requires: transition -> places that must be activated before it.
            feeds: transition -> places that firing it can activate.
        """
        # Sorted iteration keeps the assertion order deterministic
        # (t1..t4, then p1..p4).
        transitions = sorted(T)
        for t in transitions:
            ordering = []
            for p in requires[t]:
                ordering += [z[p] > 0, z[p] < z[t]]
            solver.add(z3.Implies(x[t] > 0, z3.And(z[t] > 0, *ordering)))

        for p in sorted(P):
            enablers = [z3.And(x[t] > 0, z[t] < z[p])
                        for t in transitions if p in feeds[t]]
            solver.add(z3.Implies(z[p] > 0, z3.Or(marking[p] > 0, *enablers)))

    add_firability(z_fwd, u, requires=pre, feeds=post)
    add_firability(z_bwd, v, requires=post, feeds=pre)
    
    # Check continuous reachability
    def continuous_reach(m_src, m_tgt):
        """Print whether marking ``m_tgt`` is continuously reachable from ``m_src``.

        The source marking ``u`` and the target marking ``v`` are pinned to
        the given values inside a push/pop scope. The net encoding already
        held by ``solver`` is therefore unchanged for later queries.

        Args:
            m_src: mapping from every place in ``P`` to its token count in
                the source marking.
            m_tgt: mapping from every place in ``P`` to its token count in
                the target marking.

        Prints one of three answers:
            "Yes, e.g. with x = [...]": includes a witnessing firing vector
                (values for the transitions in sorted order).
            "No": the constraints are unsatisfiable.
            "Unknown: <reason>": Z3 could not decide.
        """
        solver.push()
        try:
            solver.add(z3.And([u[p] == m_src[p] for p in P]))
            solver.add(z3.And([v[p] == m_tgt[p] for p in P]))

            result = solver.check()

            # Inspect the result while the marking constraints are still
            # asserted: the model belongs to this scope, and pop() may
            # discard it.
            if result == z3.sat:
                model = solver.model()
                # model_completion guards against a firing variable that the
                # model leaves unassigned.
                witness = [model.eval(x[t], model_completion=True)
                           for t in sorted(T)]
                print("Yes, e.g. with x =", witness)
            elif result == z3.unknown:
                # Not a proof of unreachability, so do not report "No".
                print("Unknown:", solver.reason_unknown())
            else:
                print("No")
        finally:
            # Retract the marking constraints even if check() raised.
            solver.pop()

    # Exercise 2.1 (a), (b) and (c). Each entry holds the question as printed,
    # then the source and target markings it asks about.
    exercises = [
        ("(2, 0, 0, 0) —*—>> (0, 0, 0, 1)?",
         {"p1": 2, "p2": 0, "p3": 0, "p4": 0},
         {"p1": 0, "p2": 0, "p3": 0, "p4": 1}),
        ("(2, 0, 0, 0) —*—>> (0, 0, 1, 0)?",
         {"p1": 2, "p2": 0, "p3": 0, "p4": 0},
         {"p1": 0, "p2": 0, "p3": 1, "p4": 0}),
        ("(2, 0, 0, 0) —*—>> (1, 0, 1, 0)?",
         {"p1": 2, "p2": 0, "p3": 0, "p4": 0},
         {"p1": 1, "p2": 0, "p3": 1, "p4": 0}),
    ]

    for number, (question, source, target) in enumerate(exercises):
        # Two blank lines separate consecutive exercises.
        if number:
            print()
            print()
        print(question)
        print()
        continuous_reach(source, target)
