LCOV - code coverage report
Current view: top level - smt/include/mcrl2/smt - unfold_pattern_matching.h (source / functions) Hit Total Coverage
Test: mcrl2_coverage.info.cleaned Lines: 0 157 0.0 %
Date: 2024-04-26 03:18:02 Functions: 0 10 0.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : // Author(s): Ruud Koolen, Thomas Neele
       2             : // Copyright: see the accompanying file COPYING or copy at
       3             : // https://github.com/mCRL2org/mCRL2/blob/master/COPYING
       4             : //
       5             : // Distributed under the Boost Software License, Version 1.0.
       6             : // (See accompanying file LICENSE_1_0.txt or copy at
       7             : // http://www.boost.org/LICENSE_1_0.txt)
       8             : 
       9             : #ifndef MCRL2_SMT_UNFOLD_PATTERN_MATCHING_H
      10             : #define MCRL2_SMT_UNFOLD_PATTERN_MATCHING_H
      11             : 
      12             : #include "mcrl2/data/join.h"
      13             : #include "mcrl2/data/replace.h"
      14             : #include "mcrl2/data/representative_generator.h"
      15             : #include "mcrl2/data/unfold_pattern_matching.h"
      16             : #include "mcrl2/data/substitutions/map_substitution.h"
      17             : #include "mcrl2/data/substitutions/variable_substitution.h"
      18             : #include "mcrl2/smt/utilities.h"
      19             : 
      20             : namespace mcrl2
      21             : {
      22             : namespace smt
      23             : {
      24             : 
      25             : /**
      26             :  * \brief Contains information on sorts that behave similar to a structured sort in a data specification.
      27             :  * \details That is, there is a number of constructors, and for some constructors we have a recogniser function
      28             :  * and several projection functions.
      29             :  */
      30             : struct structured_sort_functions
      31             : {
      32             :   std::map< data::sort_expression, std::set<data::function_symbol> > constructors;
      33             :   std::map< data::function_symbol, data::function_symbol > recogniser_func;
      34             :   std::map< data::function_symbol, data::function_symbol_vector > projection_func;
      35             : 
      36           0 :   bool is_constructor(const data::function_symbol& f) const
      37             :   {
      38           0 :     const auto& cons_s = constructors.at(f.sort().target_sort());
      39           0 :     return cons_s.find(f) != cons_s.end();
      40             :   }
      41             : 
      42           0 :   const std::set<data::function_symbol>& get_constructors(const data::sort_expression& sort) const
      43             :   {
      44           0 :     return constructors.at(sort);
      45             :   }
      46             : 
      47           0 :   data::data_expression create_recogniser_expr(const data::function_symbol& f, const data::data_expression& expr) const
      48             :   {
      49           0 :     return data::application(recogniser_func.at(f), expr);
      50             :   }
      51             : 
      52           0 :   data::data_expression create_cases(const data::data_expression& target, const data::data_expression_vector& rhss)
      53             :   {
      54           0 :     std::set<data::function_symbol> constr = get_constructors(target.sort());
      55           0 :     auto const_it = constr.begin();
      56           0 :     auto rhs_it = rhss.begin();
      57           0 :     data::data_expression result = *rhs_it++;
      58           0 :     for (const_it++; const_it != constr.end(); ++const_it, ++rhs_it)
      59             :     {
      60           0 :       data::data_expression term = *rhs_it;
      61           0 :       data::data_expression condition = create_recogniser_expr(*const_it, target);
      62           0 :       result = data::lazy::if_(condition, term, result);
      63           0 :     }
      64           0 :     return result;
      65           0 :   }
      66             : 
      67           0 :   const data::function_symbol_vector& get_projection_funcs(const data::function_symbol& f) const
      68             :   {
      69           0 :     return projection_func.at(f);
      70             :   }
      71             : };
      72             : 
      73             : template <typename T>
      74             : struct always_false
      75             : {
      76             :   bool operator()(const T&)
      77             :   {
      78             :     return false;
      79             :   }
      80             : };
      81             : 
      82             : /// @brief Find sorts that behave like a structured sort and the associated rewrite rules
      83             : /// @tparam Skip Unary Boolean function type.
      84             : /// @param dataspec The data specification to consider
      85             : /// @param skip If skip(f) is true then function symbol f will not be considered
      86             : /// @return A pair containing: (1) recogniser and projection function symbols for each structured sort and
      87             : ///         (2) a map that gives a list of equations for each function symbol.
      88             : template <typename Skip = always_false<data::function_symbol>>
      89             : std::pair<structured_sort_functions, std::map< data::function_symbol, data::data_equation_vector >>
      90           0 : find_structured_sort_functions(const data::data_specification& dataspec, Skip skip = Skip())
      91             : {
      92           0 :   structured_sort_functions ssf;
      93           0 :   for(const data::sort_expression& s: dataspec.sorts())
      94             :   {
      95           0 :     ssf.constructors[s] = std::set<data::function_symbol>(dataspec.constructors(s).begin(), dataspec.constructors(s).end());
      96             :   }
      97             : 
      98           0 :   std::map< data::function_symbol, data::data_equation_vector > rewrite_rules;
      99           0 :   for (const data::data_equation& eqn: dataspec.equations())
     100             :   {
     101           0 :     data::data_expression lhs = eqn.lhs();
     102           0 :     data::function_symbol function = data::detail::get_top_fs(lhs);
     103           0 :     if (function == data::function_symbol())
     104             :     {
     105           0 :       continue;
     106             :     }
     107           0 :     if (skip(function))
     108             :     {
     109           0 :       continue;
     110             :     }
     111             :     //TODO equations of the shape x < x or x <= x are simply removed, so the remaining equations
     112             :     // form a valid pattern matching. How can this problem be adressed in a proper way?
     113           0 :     if(eqn.variables().size() == 1 && (core::pp(function.name()) == "<" || core::pp(function.name()) == "<="))
     114             :     {
     115           0 :       continue;
     116             :     }
     117             : 
     118           0 :     rewrite_rules[function].push_back(eqn);
     119             :   }
     120             : 
     121             :   // For each mapping, find out whether it is a recogniser or projection function.
     122           0 :   for (const auto& [mapping, equations]: rewrite_rules)
     123             :   {
     124           0 :     if (!data::is_function_sort(mapping.sort()))
     125             :     {
     126           0 :       continue;
     127             :     }
     128           0 :     data::function_sort sort(mapping.sort());
     129           0 :     if (sort.domain().size() != 1)
     130             :     {
     131           0 :       continue;
     132             :     }
     133           0 :     data::sort_expression domain = sort.domain().front();
     134           0 :     if (ssf.constructors[domain].empty())
     135             :     {
     136           0 :       continue;
     137             :     }
     138             : 
     139             :     // TODO implement this using a rewriter, which is a much easier way the find
     140             :     // the same patterns that are implemented manually below.
     141             :     // Check for recognisers.
     142           0 :     if (data::sort_bool::is_bool(sort.codomain()))
     143             :     {
     144           0 :       std::set<data::function_symbol> positive_recogniser_equation_seen;
     145           0 :       std::set<data::function_symbol> negative_recogniser_equation_seen;
     146           0 :       bool invalid_equations_seen = false;
     147           0 :       for (const data::data_equation& eqn: equations)
     148             :       {
     149           0 :         if (eqn.condition() != data::sort_bool::true_() ||
     150           0 :             !data::is_application(eqn.lhs()))
     151             :         {
     152           0 :           invalid_equations_seen = true;
     153           0 :           break;
     154             :         }
     155             : 
     156           0 :         data::application application(eqn.lhs());
     157           0 :         assert(application.head() == mapping);
     158           0 :         assert(application.size() == 1);
     159           0 :         data::data_expression argument(application[0]);
     160           0 :         data::function_symbol constructor = data::detail::get_top_fs(argument);
     161           0 :         if (constructor == data::function_symbol())
     162             :         {
     163           0 :           invalid_equations_seen = true;
     164           0 :           break;
     165             :         }
     166           0 :         if (data::is_application(argument))
     167             :         {
     168           0 :           const data::application& constructor_application = atermpp::down_cast<data::application>(argument);
     169           0 :           bool all_args_are_vars = std::all_of(constructor_application.begin(), constructor_application.end(), &data::is_variable);
     170           0 :           bool all_vars_are_unique = data::find_all_variables(constructor_application).size() == constructor_application.size();
     171           0 :           if(!all_args_are_vars || !all_vars_are_unique)
     172             :           {
     173           0 :             invalid_equations_seen = true;
     174           0 :             break;
     175             :           }
     176             :         }
     177             :         // Check if the function symbol we found is really a constructor
     178           0 :         if (ssf.constructors[domain].count(constructor) == 0)
     179             :         {
     180           0 :           invalid_equations_seen = true;
     181           0 :           break;
     182             :         }
     183             : 
     184           0 :         if (eqn.rhs() == data::sort_bool::true_())
     185             :         {
     186           0 :           positive_recogniser_equation_seen.insert(constructor);
     187           0 :           if (negative_recogniser_equation_seen.count(constructor) != 0)
     188             :           {
     189           0 :             invalid_equations_seen = true;
     190           0 :             break;
     191             :           }
     192             :         }
     193           0 :         else if (eqn.rhs() == data::sort_bool::false_())
     194             :         {
     195           0 :           negative_recogniser_equation_seen.insert(constructor);
     196           0 :           if (positive_recogniser_equation_seen.count(constructor) != 0)
     197             :           {
     198           0 :             invalid_equations_seen = true;
     199           0 :             break;
     200             :           }
     201             :         }
     202             :         else
     203             :         {
     204           0 :           invalid_equations_seen = true;
     205           0 :           break;
     206             :         }
     207             :       }
     208           0 :       if (!invalid_equations_seen &&
     209           0 :           positive_recogniser_equation_seen.size() == 1 &&
     210           0 :           positive_recogniser_equation_seen.size() + negative_recogniser_equation_seen.size() == ssf.constructors[domain].size())
     211             :       {
     212           0 :         data::function_symbol constructor = *positive_recogniser_equation_seen.begin();
     213           0 :         ssf.recogniser_func[constructor] = mapping;
     214           0 :       }
     215           0 :     }
     216             : 
     217             :     // Check for projections.
     218           0 :     if (equations.size() == 1)
     219             :     {
     220           0 :       data::data_equation equation = equations[0];
     221           0 :       if (equation.condition() == data::sort_bool::true_() &&
     222           0 :           data::is_variable(equation.rhs()) &&
     223           0 :           data::is_application(equation.lhs()))
     224             :       {
     225           0 :         data::application application(equation.lhs());
     226           0 :         assert(application.head() == mapping);
     227           0 :         assert(application.size() == 1);
     228           0 :         data::data_expression argument(application[0]);
     229           0 :         if (data::is_application(argument) &&
     230           0 :             data::is_function_symbol(data::application(argument).head()) &&
     231           0 :             ssf.constructors[domain].count(data::function_symbol(data::application(argument).head())) == 1)
     232             :         {
     233           0 :           data::application constructor_application(argument);
     234           0 :           data::function_symbol constructor(constructor_application.head());
     235             : 
     236           0 :           bool all_args_are_vars = std::all_of(constructor_application.begin(), constructor_application.end(), &data::is_variable);
     237           0 :           bool all_vars_are_unique = data::find_all_variables(constructor_application).size() == constructor_application.size();
     238           0 :           auto find_result = std::find(constructor_application.begin(), constructor_application.end(), equation.rhs());
     239             : 
     240           0 :           if (find_result != constructor_application.end() && all_args_are_vars && all_vars_are_unique)
     241             :           {
     242           0 :             data::application::const_iterator::difference_type index = find_result - constructor_application.begin();
     243           0 :             assert(index >= 0 && index < static_cast<data::application::const_iterator::difference_type>(constructor_application.size()));
     244           0 :             ssf.projection_func[constructor].resize(constructor_application.size());
     245           0 :             ssf.projection_func[constructor][index] = mapping;
     246             :           }
     247           0 :         }
     248           0 :       }
     249           0 :     }
     250             :   }
     251             : 
     252           0 :   return std::make_pair(ssf, rewrite_rules);
     253           0 : }
     254             : 
     255             : /**
     256             :  * \brief Complete the containers with recognisers and projections in ssf
     257             :  * \details Also sets native translations and build a set of all recognisers and
     258             :  * projections in dataspec.
     259             :  */
     260           0 : std::set<data::function_symbol> complete_recognisers_projections(const data::data_specification& dataspec, native_translations& nt, structured_sort_functions& ssf)
     261             : {
     262           0 :   std::set<data::function_symbol> recog_and_proj;
     263             : 
     264           0 :   for(const data::function_symbol& cons: dataspec.constructors())
     265             :   {
     266           0 :     auto find_result = ssf.recogniser_func.find(cons);
     267           0 :     if(find_result != ssf.recogniser_func.end())
     268             :     {
     269           0 :       nt.set_native_definition(find_result->second, make_recogniser_name(cons, nt));
     270           0 :       recog_and_proj.insert(find_result->second);
     271             :     }
     272             :     else
     273             :     {
     274           0 :       ssf.recogniser_func[cons] = make_recogniser_func(cons, nt);
     275             :     }
     276             : 
     277           0 :     if(data::is_function_sort(cons.sort()))
     278             :     {
     279           0 :       std::size_t index = 0;
     280           0 :       const data::sort_expression_list& arg_list = atermpp::down_cast<data::function_sort>(cons.sort()).domain();
     281           0 :       ssf.projection_func[cons].resize(arg_list.size());
     282           0 :       for(const data::sort_expression& arg: arg_list)
     283             :       {
     284           0 :         data::function_symbol& projection = ssf.projection_func[cons][index];
     285           0 :         if(projection != data::function_symbol())
     286             :         {
     287           0 :           nt.set_native_definition(projection, make_projection_name(cons, index, nt));
     288           0 :           recog_and_proj.insert(projection);
     289             :         }
     290             :         else
     291             :         {
     292           0 :           projection = make_projection_func(cons, arg, index, nt);
     293             :         }
     294           0 :         index++;
     295             :       }
     296             :     }
     297             :   }
     298             : 
     299           0 :   return recog_and_proj;
     300           0 : }
     301             : 
     302             : inline
     303           0 : void unfold_pattern_matching(const data::data_specification& dataspec, native_translations& nt)
     304             : {
     305           0 :   std::set<core::identifier_string> used_ids = data::find_identifiers(dataspec);
     306           0 :   auto p = find_structured_sort_functions(dataspec, [&nt](const data::function_symbol& f){ return nt.has_native_definition(f); });
     307           0 :   structured_sort_functions& ssf = p.first;
     308           0 :   std::map<data::function_symbol, data::data_equation_vector>& rewrite_rules = p.second;
     309             : 
     310           0 :   std::set<data::function_symbol> recog_and_proj = complete_recognisers_projections(dataspec, nt, ssf);
     311             : 
     312           0 :   data::representative_generator rep_gen(dataspec);
     313           0 :   for(const auto& [function, rewr_equations]: rewrite_rules)
     314             :   {
     315             :     // Only unfold equations with parameters
     316             :     // Do not unfold recognisers and projection functions
     317             :     // Only unfold equations that satisfy the function 'is_pattern_matching_rule'
     318           0 :     if (data::is_function_sort(function.sort()) &&
     319           0 :         recog_and_proj.find(function) == recog_and_proj.end() &&
     320           0 :         std::all_of(rewr_equations.begin(),
     321             :                     rewr_equations.end(),
     322           0 :                     [&ssf](const data::data_equation& eqn){ return data::is_pattern_matching_rule(ssf, eqn); }))
     323             :     {
     324           0 :       data::set_identifier_generator id_gen;
     325           0 :       id_gen.add_identifiers(used_ids);
     326           0 :       data::data_equation unfolded_eqn = data::unfold_pattern_matching(function, rewr_equations, ssf, rep_gen, id_gen);
     327           0 :       nt.set_native_definition(function, unfolded_eqn);
     328           0 :     }
     329             :   }
     330           0 : }
     331             : 
     332             : } // namespace smt
     333             : } // namespace mcrl2
     334             : 
     335             : #endif

Generated by: LCOV version 1.14