Changeset 6243


Ignore:
Timestamp:
Apr 16, 2014 4:55:44 PM (6 years ago)
Author:
nlandin
Message:

Fixed time variable not being initialized correctly when importing optimization problems. #3598

Location:
branches/NewXMLExport
Files:
5 edited
1 copied

Legend:

Unmodified
Added
Removed
  • branches/NewXMLExport/ModelicaCasADiInterface/src/Model.cpp

    r6149 r6243  
    294294Ref<Variable> Model::getVariable(std::string name) {
    295295    Ref<Variable> returnVar = Ref<Variable>(NULL);
    296     returnVar = modelVariableMap.find(name)->second;
     296    std::map<std::string, Variable*>::iterator it = modelVariableMap.find(name);
     297    if (it != modelVariableMap.end()) {
     298        returnVar = it->second;;
     299    } else {
     300        return NULL;
     301    }
     302    //returnVar = modelVariableMap.find(name)->second;
    297303    /*for (vector< Variable * >::iterator it = z.begin(); it != z.end(); ++it) {
    298304        if ((*it)->getName() == name) {
  • branches/NewXMLExport/ModelicaCasADiInterface/src/Model.hpp

    r6149 r6243  
    3838        typedef std::map< std::string, Ref<ModelFunction> > functionMap;
    3939        typedef std::map< std::string, Ref<VariableType> > typeMap;
     40    protected:
    4041        typedef std::map<std::string, Variable*> variableMap;
    4142    public:
     
    239240        void assignVariableTypeToVariable(Ref<Variable> var);
    240241       
     242    protected:
    241243        variableMap modelVariableMap;
    242244};
  • branches/NewXMLExport/ModelicaCasADiInterface/src/transferXML.cpp

    r6216 r6243  
    227227/**
    228228 * Construct an MXFunction from the XML and adds it to the model
     229 * TODO: Split up in several smaller parts, one for handling assignments and one for functioncalls
    229230 */
    230231void transferFunction(Ref<Model> m, XMLElement* elem) {
     
    241242                }
    242243                if (!strcmp(stmt->Value(), "assign")) {
    243                     MXVector lhs = MXVector();
    244244                    XMLElement* left = stmt->FirstChildElement()->FirstChildElement();
    245245                    XMLElement* checkRight = stmt->FirstChildElement()->NextSiblingElement()->FirstChildElement();
    246                     if (!strcmp(checkRight->Value(), "call") && !(checkRight->Attribute("builtin") != NULL &&
     246                    // update function variables to reflect function call
     247                    if (left != NULL && !strcmp(checkRight->Value(), "call") && !(checkRight->Attribute("builtin") != NULL &&
    247248                        strcmp(checkRight->Attribute("builtin"), "array"))) {
    248                         if (!strcmp(left->Value(), "tuple")) {
    249                             for (XMLElement* tupleChild = left->FirstChildElement(); tupleChild != NULL; tupleChild = tupleChild->NextSiblingElement()) {
    250                                 if (!strcmp(tupleChild->Value(), "nothing")) {
    251                                     lhs.push_back(MX());
    252                                 } else {
    253                                     MX leftCas;
    254                                     int index = findIndex(vars, tupleChild->Attribute("name"));
    255                                     if (index != -1) {
    256                                         leftCas = vars.at(index);
    257                                     } else {
    258                                         leftCas = MX(tupleChild->Attribute("name"));
    259                                     }
    260                                     lhs.push_back(leftCas);
    261                                 }
    262                             }
    263                         } else if (!strcmp(left->Value(), "local")) {
    264                             string varName = left->Attribute("name");
    265                             for (XMLElement* leftChild = left->FirstChildElement(); leftChild != NULL; leftChild = leftChild->NextSiblingElement()) {
    266                                 // add members to name
    267                                 if (!strcmp(leftChild->Value(), "member")) {
    268                                     varName += ".";
    269                                     varName += leftChild->Attribute("name");
    270                                 }
    271                             }
    272                             int index = findIndex(vars, varName);
    273                             if (index != -1) {
    274                                 lhs.push_back(vars.at(index));
    275                             } else {
    276                                 // check if this local variable is an array by looking it up in the dimensions map
    277                                 // if it is we need to handle it by constructing the scalar variable names
    278                                 // and lookup them in the vars so that the correct variables are swapped
    279                                 if (dimensionMap.count((functionName + varName)) != 0) {
    280                                     std::vector<string> arrayVars = getArrayVariables(left, functionName);
    281                                     for (int i=0; i < arrayVars.size(); i++) {
    282                                         int index = findIndex(vars, arrayVars.at(i));
    283                                         lhs.push_back(vars.at(index));
    284                                     }
    285                                 } else {
    286                                     lhs.push_back(MX(varName));
    287                                 }
    288                             }
    289                         }
    290                         XMLElement* right = stmt->FirstChildElement()->NextSiblingElement()->FirstChildElement()->FirstChildElement();
    291                         string funcName = right->FirstChildElement()->Attribute("name");
    292                         CasADi::MXFunction f = m->getModelFunction(funcName)->getMx();
    293                         MXVector argVec = MXVector();
    294                         for (XMLElement* arg = right->NextSiblingElement(); arg != NULL; arg = arg->NextSiblingElement()) {
    295                             if (!strcmp(arg->Value(), "call")) {
    296                                 if (arg->Attribute("builtin") != NULL && !strcmp(arg->Attribute("builtin"), "array")) {
    297                                     // array constructor
    298                                     for (XMLElement* arr = arg->FirstChildElement(); arr != NULL; arr = arr->NextSiblingElement()) {
    299                                         MX arrCall = expressionToMx(m, arr);
    300                                         for (int i=0; i < arrCall.size(); i++) {
    301                                             argVec.push_back(arrCall.at(i));
    302                                         }
    303                                     }
    304                                 } else if (arg->Attribute("builtin") != NULL) {
    305                                     // builtin function
    306                                     argVec.push_back(expressionToMx(m, arg));
    307                                 } else {
    308                                     // regular function call
    309                                     MX func = expressionToMx(m, arg);
    310                                     for (int i=0; i < func.size(); i++) {
    311                                         argVec.push_back(func.at(i));
    312                                     }
    313                                 }
    314                             } else {
    315                                 // check if array var
    316                                 if (arg->Attribute("name") != NULL && dimensionMap.count((functionName + arg->Attribute("name"))) != 0) {
    317                                     std::vector<string> arrayVars = getArrayVariables(arg, functionName);
    318                                     for (int i=0; i < arrayVars.size(); i++) {
    319                                         MX arg = funcVars.find(arrayVars.at(i))->second->getVar();
    320                                         argVec.push_back(arg);
    321                                     }
    322                                 } else {
    323                                     argVec.push_back(expressionToMx(m, arg));
    324                                 }
    325                             }
    326                         }
    327                         MXVector outputs = f.call(argVec);
    328                         MXVector updatedVec = CasADi::substitute(outputs, vars, expressions);
    329                         for (int i=0; i < lhs.size(); i++) {
    330                             if (!lhs.at(i).isNull()) {
    331                                 int index = findIndex(vars, lhs.at(i).getName());
    332                                 expressions.at(index) = updatedVec.at(i);
    333                             }
    334                         }
    335                     } else {
     249                            updateFunctionCall(m, stmt, expressions, vars, functionName);
     250                    } else if (left != NULL) {
     251                        MXVector lhs = MXVector();
    336252                        MX leftCas;
    337253                        if (!strcmp(left->Value(), "reference")) {
    338                             int flatIndex = calculateFlatArrayIndex(left, functionName);
     254                            int flatIndex = calculateFlatArrayIndex(m, left, functionName);
    339255                            string varName(left->FirstChildElement()->Attribute("name"));
    340256                            std::stringstream ss;
     
    405321    f.init();
    406322    m->setModelFunctionByItsName(new ModelicaCasADi::ModelFunction(f));
     323}
     324
     325/**
     326 * Handles the updating of function calls in functions. The expression vector which
     327 * contains the current MX for all variables in the functions is updated by
     328 * running the function call and then substitute in the outputs.
     329 */
     330void updateFunctionCall(Ref<Model> m, XMLElement* stmt, MXVector &expressions, MXVector &vars, string functionName) {
     331    XMLElement* left = stmt->FirstChildElement()->FirstChildElement();
     332    MXVector lhs = MXVector();
     333    if (!strcmp(left->Value(), "tuple")) {
     334        for (XMLElement* tupleChild = left->FirstChildElement(); tupleChild != NULL; tupleChild = tupleChild->NextSiblingElement()) {
     335            if (!strcmp(tupleChild->Value(), "nothing")) {
     336                // add an emtpy mx if there are an empty spot in the tuple
     337                lhs.push_back(MX());
     338            } else {
     339                MX leftCas;
     340                int index = findIndex(vars, tupleChild->Attribute("name"));
     341                if (index != -1) {
     342                    leftCas = vars.at(index);
     343                } else {
     344                    leftCas = MX(tupleChild->Attribute("name"));
     345                }
     346                lhs.push_back(leftCas);
     347            }
     348        }
     349    } else if (!strcmp(left->Value(), "local")) {
     350        string varName = left->Attribute("name");
     351        for (XMLElement* leftChild = left->FirstChildElement(); leftChild != NULL; leftChild = leftChild->NextSiblingElement()) {
     352            // add members to name
     353            if (!strcmp(leftChild->Value(), "member")) {
     354                varName += ".";
     355                varName += leftChild->Attribute("name");
     356            }
     357        }
     358        int index = findIndex(vars, varName);
     359        if (index != -1) {
     360            lhs.push_back(vars.at(index));
     361        } else {
     362            // check if this local variable is an array by looking it up in the dimensions map
     363            // if it is we need to handle it by constructing the scalar variable names
     364            // and lookup them in the vars so that the correct variables are swapped
     365            if (dimensionMap.count((functionName + varName)) != 0) {
     366                std::vector<string> arrayVars = getArrayVariables(left, functionName);
     367                for (int i=0; i < arrayVars.size(); i++) {
     368                    int index = findIndex(vars, arrayVars.at(i));
     369                    lhs.push_back(vars.at(index));
     370                }
     371            } else {
     372                lhs.push_back(MX(varName));
     373            }
     374        }
     375    }
     376    XMLElement* right = stmt->FirstChildElement()->NextSiblingElement()->FirstChildElement()->FirstChildElement();
     377    string funcName = right->FirstChildElement()->Attribute("name");
     378    CasADi::MXFunction f = m->getModelFunction(funcName)->getMx();
     379    MXVector argVec = MXVector();
     380    for (XMLElement* arg = right->NextSiblingElement(); arg != NULL; arg = arg->NextSiblingElement()) {
     381        if (!strcmp(arg->Value(), "call")) {
     382            if (arg->Attribute("builtin") != NULL && !strcmp(arg->Attribute("builtin"), "array")) {
     383                // array constructor
     384                for (XMLElement* arr = arg->FirstChildElement(); arr != NULL; arr = arr->NextSiblingElement()) {
     385                    MX arrCall = expressionToMx(m, arr);
     386                    for (int i=0; i < arrCall.size(); i++) {
     387                        argVec.push_back(arrCall.at(i));
     388                    }
     389                }
     390            } else if (arg->Attribute("builtin") != NULL) {
     391                // builtin function
     392                argVec.push_back(expressionToMx(m, arg));
     393            } else {
     394                // regular function call
     395                MX func = expressionToMx(m, arg);
     396                for (int i=0; i < func.size(); i++) {
     397                    argVec.push_back(func.at(i));
     398                }
     399            }
     400        } else {
     401            // check if array var
     402            if (arg->Attribute("name") != NULL && dimensionMap.count((functionName + arg->Attribute("name"))) != 0) {
     403                std::vector<string> arrayVars = getArrayVariables(arg, functionName);
     404                for (int i=0; i < arrayVars.size(); i++) {
     405                    MX arg = funcVars.find(arrayVars.at(i))->second->getVar();
     406                    argVec.push_back(arg);
     407                }
     408            } else {
     409                argVec.push_back(expressionToMx(m, arg));
     410            }
     411        }
     412    }
     413    MXVector outputs = f.call(argVec);
     414    MXVector updatedVec = CasADi::substitute(outputs, vars, expressions);
     415    for (int i=0; i < lhs.size(); i++) {
     416        if (!lhs.at(i).isNull()) {
     417            int index = findIndex(vars, lhs.at(i).getName());
     418            expressions.at(index) = updatedVec.at(i);
     419        }
     420    }
    407421}
    408422
     
    633647        return MX(1);
    634648    } else if (!strcmp(name, "false")) {
    635         return CasADi::MX(0);
     649        return MX(0);
    636650    } else if (!strcmp(name, "local")) {
    637651        string varName = expression->Attribute("name");
     
    645659            return funcVars.find(varName)->second->getVar();
    646660        } else if (m->getVariable(varName) != NULL) {
    647             std::cout << m->getVariable(varName) << std::endl;
    648661            return m->getVariable(varName)->getVar();
    649662        }
    650         std::cout << "new mx: " << m->getVariable(varName) << std::endl;
    651663        return MX(varName);
    652664    } else if (!strcmp(name, "call")) {
     
    781793        return MX(preCall);
    782794    } else if (!strcmp(op->Attribute("name"), "assert")) {
     795        return MX(0);
    783796        // ignore asserts
    784797    } else if (!strcmp(op->Attribute("name"), "time")) {
     
    795808}
    796809
     810/**
     811 * Takes a reference tag and converts it to a MX expression
     812 */
    797813MX referenceToMx(Ref<Model> m, XMLElement* ref) {
    798814    // get the name of the function since this is needed to look up array dimensions
     
    804820    }
    805821    XMLElement* varName = ref->FirstChildElement();
    806     int flatIndex = calculateFlatArrayIndex(ref, functionName);
     822    int flatIndex = calculateFlatArrayIndex(m, ref, functionName);
    807823    string var (varName->Attribute("name"));
    808824    std::stringstream ss;
     
    815831}
    816832
     833/**
     834 * Convert if expression to MX
     835 */
    817836MX ifExpToMx(Ref<Model> m, XMLElement* expression) {
    818837    XMLElement* branching = expression->FirstChildElement();
     
    10561075 * calculate a flat index from these dimensions
    10571076 */
    1058 int calculateFlatArrayIndex(XMLElement* reference, string functionName) {
     1077int calculateFlatArrayIndex(Ref<Model> m, XMLElement* reference, string functionName) {
    10591078    XMLElement* varName = reference->FirstChildElement();
    10601079    std::vector<int> dimensions = dimensionMap.find((functionName + varName->Attribute("name")))->second;
    10611080    std::vector<int> subscripts;
    10621081    for (XMLElement* sub = varName->NextSiblingElement(); sub != NULL; sub = sub->NextSiblingElement()) {
    1063         if (strcmp(sub->Value(), "subscripts")) {
    1064             break;
    1065         }
    1066         subscripts.push_back(atoi(sub->FirstChildElement()->Attribute("value"))-1);
     1082        if (!strcmp(sub->FirstChildElement()->Value(), "call")) {
     1083            MX tmp = expressionToMx(m, sub->FirstChildElement());
     1084            subscripts.push_back(tmp.getValue()-1);
     1085        } else if (!strcmp(sub->FirstChildElement()->Value(), "integer") || !strcmp(sub->FirstChildElement()->Value(), "real")) {
     1086            subscripts.push_back(atoi(sub->FirstChildElement()->Attribute("value"))-1);
     1087        } else {
     1088            throw std::runtime_error("Only integer expressions and constants are supported as array indices");
     1089        }
    10671090    }
    10681091    // convert subscripts to flat index
  • branches/NewXMLExport/ModelicaCasADiInterface/src/transferXML.hpp

    r6194 r6243  
    5656void transferEquations(ModelicaCasADi::Ref<ModelicaCasADi::Model> m, tinyxml2::XMLElement* elem);
    5757void transferParameters(ModelicaCasADi::Ref<ModelicaCasADi::Model> m, tinyxml2::XMLElement* elem);
     58
    5859void transferFunction(ModelicaCasADi::Ref<ModelicaCasADi::Model> m, tinyxml2::XMLElement* elem);
    59 
     60void updateFunctionCall(ModelicaCasADi::Ref<ModelicaCasADi::Model> m, tinyxml2::XMLElement* stmt,
     61    CasADi::MXVector &expressions, CasADi::MXVector &vars, std::string functionName);
    6062CasADi::MXVector getInputVector(ModelicaCasADi::Ref<ModelicaCasADi::Model>, tinyxml2::XMLElement* elem);
    6163CasADi::MXVector getFuncVars(ModelicaCasADi::Ref<ModelicaCasADi::Model> m, tinyxml2::XMLElement *elem);
     
    8587void addFunctionHeaders(ModelicaCasADi::Ref<ModelicaCasADi::Model> m, tinyxml2::XMLElement* elem);
    8688
    87 int calculateFlatArrayIndex(tinyxml2::XMLElement* reference, std::string functionName);
     89int calculateFlatArrayIndex(ModelicaCasADi::Ref<ModelicaCasADi::Model> m, tinyxml2::XMLElement* reference, std::string functionName);
    8890std::vector<std::string> getArrayVariables(tinyxml2::XMLElement* elem, std::string functionName);
    8991
  • branches/NewXMLExport/ModelicaCasADiInterface/src/transferXMLOptimization.cpp

    r6216 r6243  
    3232    const std::vector<string> &modelFiles) {
    3333        // transfer model parts first
    34         //transferXmlModel(optProblem, modelName, modelFiles);
    35 
     34        transferXmlModel(optProblem, modelName, modelFiles);
    3635        optProblem->initializeProblem(modelName, true);
     36        optProblem->setTimeVariable(MX("time"));
     37
    3738        string fullPath;
    3839        for (int i=0; i < modelFiles.size(); i++) {
     
    4950        bool lagrangeSet = false;
    5051        bool mayerSet = false;
    51         for(XMLElement* rootChild = root->FirstChildElement(); rootChild != NULL; rootChild = rootChild->NextSiblingElement()) {
    52             if (!strcmp(rootChild->Value(), "component") || !strcmp(rootChild->Value(), "classDefinition")) {
     52        for (XMLElement* rootChild = root->FirstChildElement(); rootChild != NULL; rootChild = rootChild->NextSiblingElement()) {
     53            /*if (!strcmp(rootChild->Value(), "component") || !strcmp(rootChild->Value(), "classDefinition")) {
    5354                transferVariables(optProblem, rootChild);
    5455            } else if (!strcmp(rootChild->Value(), "equation")) {
     
    6566                    transferEquations(optProblem, rootChild);
    6667                }
    67             } else if (!strcmp(rootChild->Value(), "objective")) {
     68            } else*/
     69            if (!strcmp(rootChild->Value(), "objective")) {
    6870                mayerSet = true;
    6971                transferObjective(optProblem, rootChild);
     
    9496
    9597void transferObjectiveIntegrand(Ref<OptimizationProblem> optProblem, XMLElement* objectiveIntegrand) {
    96     std::cout << "integrand" << std::endl;
    9798    optProblem->setLagrangeTerm(expressionToMx(optProblem, objectiveIntegrand->FirstChildElement()));
    98     std::cout << "end integrand" << std::endl;
    9999}
    100100
Note: See TracChangeset for help on using the changeset viewer.